[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher (#5168)

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, unifying it with the
existing GEMM dispatcher infrastructure in both architecture and
usability. The dispatcher provides a unified kernel dispatch system with
C++ and Python frontends, and until now supported only GEMM operations.
This PR lets framework integrators use the same declarative kernel
workflow for convolutions as for GEMM: declare kernels, JIT-build a
registry, select kernels from the registry at runtime, and dispatch to
the GPU. Future PRs will add runtime kernel-selection heuristics that
autotune kernel parameters for a given (problem, hardware architecture)
pair.
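The declare/build/select/dispatch workflow can be sketched in pure Python. All names below (`KernelKey`, `Registry`, `Dispatcher`) are illustrative stand-ins for the pattern, not the dispatcher's actual API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class KernelKey:          # stands in for a key like GroupedConvKernelKey
    op: str               # "conv_fwd", "conv_bwd_data", "conv_bwd_weight"
    dtype: str
    tile: tuple           # tile_m x tile_n x tile_k

class Registry:
    def __init__(self):
        self._kernels = {}

    def declare(self, key, launcher):
        # Declaration phase: record the kernel. The real system JIT-compiles
        # the declared kernels when the registry is built.
        self._kernels[key] = launcher

    def select(self, op, dtype):
        # Runtime selection: return the first matching kernel. The real
        # system applies heuristics over tile sizes and the problem shape.
        for key, launcher in self._kernels.items():
            if key.op == op and key.dtype == dtype:
                return launcher
        raise LookupError(f"no kernel for {op}/{dtype}")

class Dispatcher:
    def __init__(self, registry):
        self.registry = registry

    def run(self, op, dtype, *tensors):
        # Dispatch: select within the registry, then launch.
        return self.registry.select(op, dtype)(*tensors)

# Usage: declare once, dispatch by (op, dtype) at runtime.
reg = Registry()
reg.declare(KernelKey("conv_fwd", "fp16", (128, 128, 64)),
            lambda x, w: f"fwd({x},{w})")
disp = Dispatcher(reg)
print(disp.run("conv_fwd", "fp16", "in", "wei"))  # fwd(in,wei)
```

The point of the key type is that kernel identity (op, dtype, tile shape) is data, so the same registry machinery serves GEMM and convolution alike.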

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher:

- `generated_conv_backend.hpp` enables `dispatcher.run(in, wei, out, problem)` for all six conv variants (fwd/bwd-data/bwd-weight x 2D/3D), with runtime heuristic kernel selection and a `GroupedConvKernelKey` carrying the full `ConvConfigBase` fields.
- The Python side adds parallel JIT compilation via `registry.build(max_workers)` and heuristic `registry.select()`.
- 7 C++ and 6 Python examples cover all directions with CPU reference validation, alongside shared infrastructure improvements (`BaseRegistry` CRTP, structured exceptions).

As a sanity check, JIT compile time for a single kernel is unchanged, while multiple kernels now compile in parallel:
| Kernels | 1 worker | 8 workers |
|---------|----------|-----------|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |
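The near-flat 8-worker times are what you would expect from farming independent kernel compilations out to a worker pool. A minimal sketch of the pattern behind a `max_workers`-style build (the compile step here is a stub standing in for a per-kernel hipcc invocation):

```python
import concurrent.futures

def compile_kernel(name):
    # Stub for one JIT compile. In the real flow each kernel compiles
    # independently, so compiles parallelize with no coordination.
    return f"{name}.so"

def build(kernel_names, max_workers=1):
    # Parallel build: wall time approaches the slowest single compile
    # instead of the sum of all compiles, matching the table above.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(compile_kernel, kernel_names))

artifacts = build(["conv_fwd_fp16_2d", "conv_bwd_data_fp16_2d"], max_workers=8)
print(artifacts)  # ['conv_fwd_fp16_2d.so', 'conv_bwd_data_fp16_2d.so']
```

`ThreadPoolExecutor.map` preserves input order, so artifacts line up with the declared kernel list regardless of which compile finishes first.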

## Test Plan

145 ephemeral unit tests have been added to cover basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes
in both the C++ and Python examples.

## Test Results

All 30 examples pass. Peak performance: 132 TFLOPS (batch-32 forward,
56x56) and 53 TFLOPS (1x1 pointwise). CPU reference accuracy:
max_abs_diff < 0.002 for all directions (fp16 vs. fp32 reference).
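For reference, a TFLOPS figure like the one above follows from the standard grouped-conv MAC count divided by kernel time. A hedged sketch (the shape and 0.1 ms timing below are illustrative, not the benchmark's actual values; K and C are taken per group):

```python
def grouped_conv_fwd_flops(N, G, K, C, H_out, W_out, Y, X):
    # MAC count for grouped 2D forward conv with per-group K and C,
    # at 2 FLOPs per multiply-accumulate.
    return 2 * N * G * K * C * H_out * W_out * Y * X

def tflops(flops, ms):
    # FLOP count and kernel time in milliseconds -> TFLOPS.
    return flops / (ms * 1e-3) / 1e12

# Illustrative shape: N=32, 56x56 output, 3x3 filter, G=1, K=C=64,
# at a hypothetical 0.1 ms kernel time.
f = grouped_conv_fwd_flops(32, 1, 64, 64, 56, 56, 3, 3)
print(round(tflops(f, 0.1), 1))  # 74.0
```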

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
Author: Vidyasagar Ananthan
Date: 2026-04-09 10:38:33 -07:00 (committed by GitHub)
Parent: 65ad35becd
Commit: 40290297cd
86 changed files with 15538 additions and 1500 deletions


@@ -1,6 +1,6 @@
# CK Tile Dispatcher
A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations.
**Validated Platform:** AMD Instinct MI300 series (gfx942)
@@ -342,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so
| `CMAKE_PREFIX_PATH` | - | ROCm installation path |
| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
WARNING: **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
WARNING: **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
---
@@ -363,6 +363,15 @@ cd build/examples
./gemm_04_heuristics # Heuristic kernel selection
./gemm_05_json_export # Registry JSON export
./gemm_06_multi_registry # Multiple registries
# Grouped Convolution Examples
./grouped_conv_01_basic # Declaration patterns + GPU execution
./grouped_conv_02_all_dirs # Forward/BwdData/BwdWeight with GPU
./grouped_conv_03_bench_val # Benchmark + CPU reference validation
./grouped_conv_04_registry_json # Heuristic selection + JSON export
./grouped_conv_05_bwd_data # Backward data + CPU validation
./grouped_conv_06_bwd_weight # Backward weight + CPU validation
./grouped_conv_07_benchmark # Multi-tile ResNet benchmark
```
### Python Examples
@@ -375,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher
# GEMM Examples
python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM
python3 examples/gemm/python/04_validation.py # CPU reference validation
python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels)
python3 examples/gemm/python/07_stress_test.py # Stress test
python3 examples/gemm/python/08_heuristics.py # Heuristic selection
# Grouped Convolution Examples
python3 examples/grouped_conv/python/01_basic_grouped_conv.py # Config patterns + registry + GPU
python3 examples/grouped_conv/python/02_forward.py # Forward 2D/3D + CPU ref
python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + CPU ref
python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref
python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark
python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON
```
### Example Output
@@ -647,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
### Data Flow
```
KernelConfig Registry Dispatcher GPU Execution
KernelConfig -> Registry -> Dispatcher -> GPU Execution
```
1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts)
@@ -843,31 +860,49 @@ make -j$(nproc)
```
dispatcher/
├── README.md # This file
├── CMakeLists.txt # Build configuration
├── include/ck_tile/dispatcher/ # C++ headers
├── dispatcher.hpp # GEMM dispatcher
├── registry.hpp # Kernel registry
└── kernel_key.hpp # Kernel configuration
├── src/ # C++ implementation
├── codegen/ # Kernel generation
├── unified_gemm_codegen.py # GEMM kernel generator
│ └── arch_specs.json # GPU specifications
├── bindings/ctypes/ # Python ctypes interface
│ └── gemm_ctypes_lib.cpp # GEMM Python library
├── examples/ # Examples
│ └── gemm/
│ ├── cpp/ # C++ GEMM examples (01-06)
│ └── python/ # Python GEMM examples (01-11)
├── scripts/ # Build scripts
└── tests/ # Unit tests
|---- README.md # This file
|---- CMakeLists.txt # Build configuration
|
|---- include/ck_tile/dispatcher/ # C++ headers
| |---- dispatcher.hpp # Main dispatcher include
| |---- registry.hpp # GEMM kernel registry
| |---- kernel_key.hpp # Kernel configuration
| |---- grouped_conv_config.hpp # Grouped conv configuration
| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
| +---- grouped_conv_utils.hpp # Grouped conv utilities
|
|---- src/ # C++ implementation
|
|---- codegen/ # Kernel generation
| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
| |---- unified_gemm_codegen.py # GEMM kernel generator
| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
| +---- arch_specs.json # GPU specifications
|
|---- python/ # Python utilities
| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
| |---- ctypes_utils.py # GEMM ctypes utilities
| +---- grouped_conv_utils.py # Grouped conv utilities
|
|---- scripts/ # Build scripts
| |---- compile_gemm_examples.py # GEMM build script
| +---- compile_grouped_conv_examples.py # Grouped conv build script
|
|---- bindings/ctypes/ # Python ctypes interface
| |---- gemm_ctypes_lib.cpp # GEMM Python library
| +---- conv_ctypes_lib.cpp # Grouped conv Python library
|
|---- examples/ # Examples
| |---- gemm/
| | |---- cpp/ # C++ GEMM examples (01-07)
| | +---- python/ # Python GEMM examples (01-11)
| +---- grouped_conv/
| |---- cpp/ # C++ Grouped Conv examples (01-07)
| +---- python/ # Python Grouped Conv examples (01-06)
|
+---- tests/ # Unit tests (C++ and Python)
```
---
@@ -879,17 +914,49 @@ dispatcher/
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
| Codegen | [codegen/README.md](codegen/README.md) |
| Python Utils | [python/README.md](python/README.md) |
| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) |
---
## Archived Content
## Grouped Convolution Support
Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`:
- `examples/conv/cpp/` - 11 C++ convolution examples
- `examples/conv/python/` - 14 Python convolution examples
- `codegen/unified_conv_codegen.py` - Conv kernel generator
- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers
- `python/conv_utils.py` - Conv Python utilities
Grouped convolution is fully supported alongside GEMM, with shared infrastructure to eliminate duplication.
### Python
```bash
# Generate grouped conv kernels
python3 codegen/unified_grouped_conv_codegen.py \
--output-dir build/generated_kernels \
--datatype fp16 --variant forward --ndim-spatial 2
# Build grouped conv examples
python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
```
### Key Files
| Component | File |
|-----------|------|
| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` |
| Python Codegen | `codegen/unified_grouped_conv_codegen.py` |
| Python Utils | `python/grouped_conv_utils.py` |
| Build Script | `scripts/compile_grouped_conv_examples.py` |
| Shared Codegen | `codegen/codegen_common.py` |
| Shared Utils | `python/dispatcher_common.py` |
### Variants
- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution
- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input
- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights
### Shared Infrastructure
GEMM and grouped convolution share common code to avoid duplication:
- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion
- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output
---


@@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher.
```
bindings/
├── ctypes/ # Python ctypes bindings (C API)
├── gemm_ctypes_lib.cpp # GEMM dispatcher C API
├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data)
├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API
├── gpu_helper.cpp # CLI helper for Python
└── CMakeLists.txt
└── README.md
|---- ctypes/ # Python ctypes bindings (C API)
| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API
| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data)
| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
| |---- gpu_helper.cpp # CLI helper for Python
| +---- CMakeLists.txt
+---- README.md
```
## ctypes Bindings
@@ -65,7 +65,7 @@ lib.dispatcher_cleanup()
| `dispatcher_export_registry_json()` | Export registry as JSON |
| `dispatcher_cleanup()` | Release resources |
### Convolution API
### Grouped Convolution API
| Function | Description |
|----------|-------------|
@@ -105,5 +105,11 @@ Output is JSON for easy parsing:
See the examples that use these bindings:
- **GEMM**: `dispatcher/examples/gemm/python/`
- **Conv**: `dispatcher/examples/conv/python/`
### Grouped Convolution
Grouped convolution C++ headers and Python utilities are in:
- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp`
- **Python Utils**: `dispatcher/python/grouped_conv_utils.py`
- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py`


@@ -78,7 +78,7 @@ endif()
# Look for forward kernels
file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp")
# Look for backward data kernels
file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp")
file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp")
# Fallback: any conv kernel (for backwards compatibility)
file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp")
@@ -112,7 +112,7 @@ endif()
# Add backward data kernel if available
if(CONV_BWDD_KERNEL_HEADERS)
list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER)
message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER})
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE)
endif()


@@ -53,6 +53,7 @@ struct ConvBwdwProblemC
int stride_d, stride_h, stride_w;
int pad_d, pad_h, pad_w;
int dilation_d, dilation_h, dilation_w;
int split_k;
};
// =============================================================================
@@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr,
grad_weight_ptr, // wei_ptr = grad_weight (output)
{}, // ds_ptr
grad_output_ptr, // out_ptr = grad_output
1 // k_batch
);
(prob->split_k > 1) ? prob->split_k : 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};


@@ -1,128 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Convolution Dispatcher ctypes Library
*
* Provides C API for Python ctypes integration.
* Supports forward convolution. Backward operations require additional headers.
*
* REQUIRED: Forward kernel header must be force-included via -include flag.
* OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE
*
* Usage from Python:
* lib = ctypes.CDLL("libdispatcher_conv.so")
* lib.conv_dispatcher_init()
* lib.conv_dispatcher_run(...)
*/
//
// Multi-kernel grouped convolution dispatcher for Python ctypes.
//
// Supports: forward / backward-data / backward-weight x 2D / 3D
//
// The dispatch header (conv_python_dispatch.hpp) is force-included via
// -include and brings in ALL compiled kernels with these aliases:
//
// 2D launchers (from include_all headers):
// SelectedConvKernelLauncher (forward 2D)
// SelectedConvBwdDataLauncher (backward-data 2D)
// SelectedConvBwdWeightLauncher (backward-weight 2D)
//
// 3D launchers (from dispatch header):
// ConvFwd3dLauncher (forward 3D)
// ConvBwdData3dLauncher (backward-data 3D)
// ConvBwdWeight3dLauncher (backward-weight 3D)
//
// Usage from Python:
// lib = ctypes.CDLL("libdispatcher_conv_lib.so")
// lib.conv_dispatcher_init()
// lib.conv_dispatcher_run(input, weight, output, &problem, stream)
#include <cstring>
#include <memory>
#include <vector>
#include <stdexcept>
#include <hip/hip_runtime.h>
#include "ck_tile/dispatcher/conv_utils.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
using namespace ck_tile::dispatcher;
// Global state (using shared_ptr for safe memory management)
static std::shared_ptr<ConvRegistry> g_registry = nullptr;
static std::shared_ptr<ConvDispatcher> g_dispatcher = nullptr;
static std::vector<const ConvKernelInstance*> g_kernels;
extern "C" {
// =============================================================================
// Initialization
// =============================================================================
int conv_dispatcher_init()
// =========================================================================
// Problem definition (matches Python ctypes struct exactly)
// =========================================================================
enum ConvDirection
{
if(g_registry)
return 0; // Already initialized
g_registry = std::make_shared<ConvRegistry>();
g_dispatcher = std::make_shared<ConvDispatcher>(g_registry.get());
// Register kernel configurations using simple ConvKernelSet
// (actual kernel launch uses the force-included SelectedConvKernelLauncher)
using namespace ck_tile::dispatcher::conv_decl;
// Forward kernels (required - must be force-included)
// Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb
ConvKernelSet fwd_set;
fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
ConvAlgorithm()
.tile(128, 128, 64) // tile_m x tile_n x tile_k
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave"),
"gfx942");
g_registry->register_set(fwd_set, ConvRegistry::Priority::High);
#ifdef CONV_BWD_DATA_AVAILABLE
// Backward data kernels
// Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16
ConvKernelSet bwd_data_set;
bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
ConvAlgorithm()
.tile(128, 128, 64) // tile_m x tile_n x tile_k
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave"),
"gfx942");
g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High);
#endif
return 0;
}
int conv_dispatcher_cleanup()
{
// shared_ptr automatically handles cleanup when reset
g_dispatcher.reset();
g_registry.reset();
g_kernels.clear();
return 0;
}
// =============================================================================
// Registry Management
// =============================================================================
int conv_dispatcher_get_kernel_count()
{
if(!g_registry)
return 0;
return static_cast<int>(g_registry->size());
}
int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
{
if(index < 0 || !buffer || buffer_size <= 0)
return -1;
if(!g_registry)
return -1;
// Use registry to get kernel names (they are registered with full names)
const auto& kernels = g_registry->all_kernels();
if(static_cast<size_t>(index) >= kernels.size())
return -1;
const auto* kernel = kernels[index];
std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1);
buffer[buffer_size - 1] = '\0';
return 0;
}
// =============================================================================
// Problem Definition
// =============================================================================
CONV_FORWARD = 0,
CONV_BWD_DATA = 1,
CONV_BWD_WEIGHT = 2
};
struct ConvProblemC
{
@@ -132,267 +50,33 @@ struct ConvProblemC
int stride_d, stride_h, stride_w;
int pad_d, pad_h, pad_w;
int dilation_d, dilation_h, dilation_w;
int direction; // 0=forward, 1=bwd_data, 2=bwd_weight
int direction;
int split_k;
};
// =============================================================================
// Kernel Selection
// =============================================================================
// =========================================================================
// Initialization / lifecycle
// =========================================================================
int conv_dispatcher_init() { return 0; }
int conv_dispatcher_cleanup() { return 0; }
int conv_dispatcher_is_supported(const ConvProblemC* prob)
{
if(!g_registry || !prob)
return 0;
ConvProblem problem;
problem.N = prob->N;
problem.G = prob->G;
problem.C = prob->C;
problem.K = prob->K;
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
problem.op = static_cast<ConvOp>(prob->direction);
problem.compute_output_size();
const auto* kernel = g_dispatcher->select(problem);
return kernel ? 1 : 0;
}
int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size)
{
if(!g_registry || !prob || !kernel_name || buffer_size <= 0)
return -1;
ConvProblem problem;
problem.N = prob->N;
problem.G = prob->G;
problem.C = prob->C;
problem.K = prob->K;
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
problem.op = static_cast<ConvOp>(prob->direction);
problem.compute_output_size();
const auto* kernel = g_dispatcher->select(problem);
if(!kernel)
return -1;
std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1);
kernel_name[buffer_size - 1] = '\0';
return 0;
}
// =============================================================================
// Convolution Execution
// =============================================================================
// Helper to build ConvParam
static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob)
{
// Determine if this is 2D or 3D convolution
const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
if(is_3d)
{
// 3D convolution: use all spatial dimensions
return ck_tile::conv::ConvParam{3,
prob->G,
prob->N,
prob->K,
prob->C,
{prob->filter_z, prob->filter_y, prob->filter_x},
{prob->input_d, prob->input_h, prob->input_w},
{prob->stride_d, prob->stride_h, prob->stride_w},
{prob->dilation_d, prob->dilation_h, prob->dilation_w},
{prob->pad_d, prob->pad_h, prob->pad_w},
{prob->pad_d, prob->pad_h, prob->pad_w}};
}
else
{
// 2D convolution: only use H, W dimensions
return ck_tile::conv::ConvParam{2,
prob->G,
prob->N,
prob->K,
prob->C,
{prob->filter_y, prob->filter_x},
{prob->input_h, prob->input_w},
{prob->stride_h, prob->stride_w},
{prob->dilation_h, prob->dilation_w},
{prob->pad_h, prob->pad_w},
{prob->pad_h, prob->pad_w}};
}
}
// Forward convolution (required - kernel header must be force-included)
static float run_forward(const void* input_ptr,
const void* weight_ptr,
void* output_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
// SelectedConvKernelLauncher is defined in the force-included forward kernel header
return SelectedConvKernelLauncher::launch(args, stream_cfg);
}
#ifdef CONV_BWD_DATA_AVAILABLE
// Backward data convolution (optional)
// Computes: grad_input = conv_bwd_data(weight, grad_output)
//
// Parameters:
// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT)
// weight_ptr: W - frozen weights (const, read-only INPUT)
// grad_input_ptr: dX - gradient for input (writable, OUTPUT)
static float run_bwd_data(const void* grad_output_ptr,
const void* weight_ptr,
void* grad_input_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
// CK Tile API uses tensor POSITION names (from forward pass), not data flow:
// in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data)
// wei_ptr = weight tensor = weight_ptr (W, const)
// out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data)
ck_tile::GroupedConvBwdDataHostArgs args(
conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
return SelectedConvBwdDataLauncher::launch(args, stream_cfg);
}
#endif
#ifdef CONV_BWD_WEIGHT_AVAILABLE
// Backward weight convolution (optional)
// Parameters:
// input_ptr: original forward input X (const, read-only)
// grad_output_ptr: gradient from next layer dY (const, read-only)
// grad_weight_ptr: gradient of weights dW (writable, OUTPUT)
static float run_bwd_weight(const void* input_ptr,
const void* grad_output_ptr,
void* grad_weight_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
// GroupedConvBwdWeightHostArgs constructor order:
// (param, in=X, wei=dW (output), ds, out=dY (input), k_batch)
// Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output)
ck_tile::GroupedConvBwdWeightHostArgs args(
conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
}
#endif
/**
* @brief Execute convolution based on direction specified in prob
*
* Parameter mapping varies by direction:
* Forward (direction=0):
* input_ptr = X (input tensor)
* weight_ptr = W (weight tensor)
* output_ptr = Y (output buffer)
*
* Backward Data (direction=1):
* input_ptr = dY (grad_output - gradient from next layer)
* weight_ptr = W (weight tensor, frozen)
* output_ptr = dX (grad_input buffer)
*
* Backward Weight (direction=2):
* input_ptr = X (forward input tensor)
* weight_ptr = dY (grad_output - gradient from next layer)
* output_ptr = dW (grad_weight buffer)
*/
float conv_dispatcher_run(const void* input_ptr,
const void* weight_ptr,
void* output_ptr,
const ConvProblemC* prob,
void* stream)
{
// Validate all required pointers before kernel launch
if(!g_dispatcher || !prob)
return -1.0f;
if(!input_ptr || !weight_ptr || !output_ptr)
return -1.0f; // Null data pointer would cause kernel crash
// Build problem for kernel selection
ConvProblem problem;
problem.N = prob->N;
problem.G = prob->G;
problem.C = prob->C;
problem.K = prob->K;
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
problem.op = static_cast<ConvOp>(prob->direction);
problem.compute_output_size();
// Select kernel
const auto* kernel = g_dispatcher->select(problem);
if(!kernel)
return -1.0f;
// Dispatch based on direction
switch(prob->direction)
{
case 0: // Forward (always available)
return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream);
#ifdef CONV_BWD_DATA_AVAILABLE
case 1: // Backward data
// Convention: caller passes (grad_output, weight, grad_input_buffer)
// in the (input_ptr, weight_ptr, output_ptr) slots respectively.
// run_bwd_data expects: (grad_output, weight, grad_input)
return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream);
#endif
#ifdef CONV_BWD_WEIGHT_AVAILABLE
case 2: // Backward weight
// Convention: caller passes (input, grad_output, grad_weight_buffer)
// in the (input_ptr, weight_ptr, output_ptr) slots respectively.
// run_bwd_weight expects: (input, grad_output, grad_weight)
return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream);
#endif
default: return -1.0f;
}
}
// =============================================================================
// Info
// =============================================================================
const char* conv_dispatcher_version() { return "1.0.0"; }
// =========================================================================
// Library info
// =========================================================================
const char* conv_dispatcher_version() { return "2.0.0"; }
int conv_dispatcher_has_kernels()
{
return 1; // Forward kernel is required
#if defined(CONV_FWD_2D_AVAILABLE) || defined(CONV_FWD_3D_AVAILABLE)
return 1;
#else
return 0;
#endif
}
int conv_dispatcher_has_bwd_data()
{
#ifdef CONV_BWD_DATA_AVAILABLE
#if defined(CONV_BWD_DATA_2D_AVAILABLE) || defined(CONV_BWD_DATA_3D_AVAILABLE)
return 1;
#else
return 0;
@@ -401,11 +85,240 @@ int conv_dispatcher_has_bwd_data()
int conv_dispatcher_has_bwd_weight()
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) || defined(CONV_BWD_WEIGHT_3D_AVAILABLE)
return 1;
#else
return 0;
#endif
}
int conv_dispatcher_get_kernel_count()
{
return CONV_KERNEL_COUNT; // defined in conv_python_dispatch.hpp
}
int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
{
if(!buffer || buffer_size <= 0 || index < 0 || index >= CONV_KERNEL_COUNT)
return -1;
std::strncpy(buffer, CONV_KERNEL_NAMES[index], buffer_size - 1);
buffer[buffer_size - 1] = '\0';
return 0;
}
// =========================================================================
// Support query
// =========================================================================
bool conv_dispatcher_is_supported(const ConvProblemC* prob)
{
if(!prob)
return false;
const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
switch(prob->direction)
{
case CONV_FORWARD:
#if defined(CONV_FWD_3D_AVAILABLE)
if(is_3d)
return true;
#endif
#if defined(CONV_FWD_2D_AVAILABLE)
if(!is_3d)
return true;
#endif
return false;
case CONV_BWD_DATA:
#if defined(CONV_BWD_DATA_3D_AVAILABLE)
if(is_3d)
return true;
#endif
#if defined(CONV_BWD_DATA_2D_AVAILABLE)
if(!is_3d)
return true;
#endif
return false;
case CONV_BWD_WEIGHT:
#if defined(CONV_BWD_WEIGHT_3D_AVAILABLE)
if(is_3d)
return true;
#endif
#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE)
if(!is_3d)
return true;
#endif
return false;
default: return false;
}
}
// =========================================================================
// ConvParam builders
// =========================================================================
static ck_tile::conv::ConvParam make_param_2d(const ConvProblemC* p)
{
return ck_tile::conv::ConvParam{2,
p->G,
p->N,
p->K,
p->C,
{p->filter_y, p->filter_x},
{p->input_h, p->input_w},
{p->stride_h, p->stride_w},
{p->dilation_h, p->dilation_w},
{p->pad_h, p->pad_w},
{p->pad_h, p->pad_w}};
}
static ck_tile::conv::ConvParam make_param_3d(const ConvProblemC* p)
{
return ck_tile::conv::ConvParam{3,
p->G,
p->N,
p->K,
p->C,
{p->filter_z, p->filter_y, p->filter_x},
{p->input_d, p->input_h, p->input_w},
{p->stride_d, p->stride_h, p->stride_w},
{p->dilation_d, p->dilation_h, p->dilation_w},
{p->pad_d, p->pad_h, p->pad_w},
{p->pad_d, p->pad_h, p->pad_w}};
}
// =========================================================================
// Kernel launch helpers
// =========================================================================
#ifdef CONV_FWD_2D_AVAILABLE
static float
launch_fwd_2d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_2d(p);
ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return SelectedConvKernelLauncher::launch(args, sc);
}
#endif
#ifdef CONV_FWD_3D_AVAILABLE
static float
launch_fwd_3d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_3d(p);
ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return ConvFwd3dLauncher::launch(args, sc);
}
#endif
#ifdef CONV_BWD_DATA_2D_AVAILABLE
static float launch_bwd_data_2d(
const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_2d(p);
ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return SelectedConvBwdDataLauncher::launch(args, sc);
}
#endif
#ifdef CONV_BWD_DATA_3D_AVAILABLE
static float launch_bwd_data_3d(
const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_3d(p);
ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return ConvBwdData3dLauncher::launch(args, sc);
}
#endif
#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE
static float launch_bwd_weight_2d(
const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_2d(p);
const int k_batch = (p->split_k > 1) ? p->split_k : 1;
ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return SelectedConvBwdWeightLauncher::launch(args, sc);
}
#endif
#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE
static float launch_bwd_weight_3d(
const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream)
{
auto param = make_param_3d(p);
const int k_batch = (p->split_k > 1) ? p->split_k : 1;
ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch);
ck_tile::stream_config sc{stream, true, 1, 3, 10};
return ConvBwdWeight3dLauncher::launch(args, sc);
}
#endif
// =========================================================================
// Main dispatch
//
// direction=0 (forward): a=X(input), b=W(weight), c=Y(output)
// direction=1 (bwd_data): a=dY(grad_out), b=W(weight), c=dX(grad_in)
// direction=2 (bwd_weight): a=X(input), b=dY(grad_out), c=dW(grad_wei)
// =========================================================================
float conv_dispatcher_run(
const void* a_ptr, const void* b_ptr, void* c_ptr, const ConvProblemC* prob, void* stream)
{
if(!prob || !a_ptr || !b_ptr || !c_ptr)
return -1.0f;
const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
try
{
switch(prob->direction)
{
case CONV_FORWARD:
#ifdef CONV_FWD_3D_AVAILABLE
if(is_3d)
return launch_fwd_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
#ifdef CONV_FWD_2D_AVAILABLE
if(!is_3d)
return launch_fwd_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
return -2.0f;
case CONV_BWD_DATA:
#ifdef CONV_BWD_DATA_3D_AVAILABLE
if(is_3d)
return launch_bwd_data_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
#ifdef CONV_BWD_DATA_2D_AVAILABLE
if(!is_3d)
return launch_bwd_data_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
return -2.0f;
case CONV_BWD_WEIGHT:
#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE
if(is_3d)
return launch_bwd_weight_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE
if(!is_3d)
return launch_bwd_weight_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
#endif
return -2.0f;
default: return -1.0f;
}
}
catch(const std::exception&)
{
return -3.0f; // Kernel rejected args (e.g. unsupported tile/channel combo)
}
catch(...)
{
return -3.0f;
}
}
} // extern "C"
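The dispatch entry point above encodes its outcome in the return value: a positive float is the measured kernel time, while -1/-2/-3 signal invalid arguments, a variant that was not compiled in, and a runtime rejection respectively. A minimal Python-side sketch of decoding that convention (the ctypes binding shown in comments is illustrative only; the library and symbol names are assumptions based on this PR's `dispatcher_conv_lib` target):

```python
import ctypes  # used by the illustrative binding sketched below

# Return-code convention of conv_dispatcher_run (from the dispatch code above):
#   > 0  -> elapsed kernel time
#   -1.0 -> null pointer or unknown direction
#   -2.0 -> no kernel compiled for the requested (direction, 2D/3D) variant
#   -3.0 -> kernel rejected the arguments at runtime
def interpret_result(code: float) -> str:
    if code > 0.0:
        return f"ok ({code:.3f} ms)"
    return {
        -1.0: "invalid arguments",
        -2.0: "variant not compiled in",
        -3.0: "kernel rejected problem",
    }.get(code, "unknown error")

# Hypothetical ctypes usage (library name and argtypes are assumptions):
# lib = ctypes.CDLL("libdispatcher_conv_lib.so")
# lib.conv_dispatcher_run.restype = ctypes.c_float
# elapsed = lib.conv_dispatcher_run(a_ptr, b_ptr, c_ptr, ctypes.byref(prob), stream)
# print(interpret_result(elapsed))
```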


@@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche
The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications:
```
arch_specs.json generate_arch_specs.py arch_specs_generated.py (Python)
arch_specs_generated.hpp (C++)
arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python)
-> arch_specs_generated.hpp (C++)
```
## Quick Start
@@ -175,14 +175,14 @@ for error in result.errors:
```
codegen/
├── arch_specs.json # Single source of truth (EDIT THIS)
├── generate_arch_specs.py # Generator script
├── arch_specs_generated.py # Generated Python module
└── ADDING_NEW_GPU.md # This file
|---- arch_specs.json # Single source of truth (EDIT THIS)
|---- generate_arch_specs.py # Generator script
|---- arch_specs_generated.py # Generated Python module
+---- ADDING_NEW_GPU.md # This file
include/ck_tile/dispatcher/
├── arch_specs_generated.hpp # Generated C++ header
└── arch_filter.hpp # C++ filter
|---- arch_specs_generated.hpp # Generated C++ header
+---- arch_filter.hpp # C++ filter
```
## Best Practices


@@ -1,11 +1,22 @@
# CK Tile GEMM Unified Code Generator
# CK Tile Unified Code Generators
Single source of truth for all GEMM kernel generation.
Single source of truth for GEMM and Grouped Convolution kernel generation.
> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
## Shared Infrastructure
Both GEMM and Grouped Conv generators share common code via `codegen_common.py`:
- `TileConfig` - Dataclass for tile dimensions
- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation
- `CommonTypeMappings` - Dtype-to-C++ type mappings
- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging
- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.)
## Quick Start
### GEMM
```bash
cd dispatcher/codegen
@@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \
--variants standard preshuffle multi_d
```
### Grouped Convolution
```bash
cd dispatcher/codegen
# Generate forward FP16 grouped conv kernels
python3 unified_grouped_conv_codegen.py \
--output-dir ../build/generated_kernels \
--datatype fp16 \
--variant forward \
--ndim-spatial 2
# Generate backward data kernels
python3 unified_grouped_conv_codegen.py \
--output-dir ../build/generated_kernels \
--variant backward_data \
--ndim-spatial 2
```
## Using from Python
```python
@@ -58,13 +88,13 @@ results = codegen.generate_all()
## Variants
### Standard
Basic GEMM: `C = A × B`
Basic GEMM: `C = A x B`
### PreShuffle
Optimized weight access with LDS pre-shuffling. Best for large matrices.
### Multi-D
Element-wise fusion: `C = op(A × B + D0 + D1 + ...)`
Element-wise fusion: `C = op(A x B + D0 + D1 + ...)`
Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
@@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
```
generated_kernels/
├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp
├── gemm_fp16_rcr_compv4_..._preshuffle.hpp
├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
└── ...
|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp # GEMM kernels
|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp
|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp # Grouped conv kernels
+---- ...
```
## Configuration Files


@@ -0,0 +1,350 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Shared codegen infrastructure for GEMM and grouped convolution code generators.
Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv.
Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here
to eliminate duplication.
"""
import logging
import concurrent.futures
from dataclasses import dataclass
from typing import (
Callable,
ClassVar,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
log = logging.getLogger(__name__)
T = TypeVar("T")
R = TypeVar("R")
ANY_INT = -1
# ============================================================================
# Tile and Trait Configuration (shared between GEMM and Conv)
# ============================================================================
@dataclass
class TileConfig:
"""Tile configuration parameters shared by GEMM and grouped conv."""
tile_m: int
tile_n: int
tile_k: int
warp_m: int
warp_n: int
warp_k: int
warp_tile_m: int
warp_tile_n: int
warp_tile_k: int
def is_valid(self) -> bool:
if self.tile_m <= 0 or self.tile_n <= 0 or self.tile_k <= 0:
return False
return (
self.tile_m % (self.warp_m * self.warp_tile_m) == 0
and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
)
@dataclass
class TraitConfigBase:
"""
Base kernel trait configuration shared by GEMM and grouped conv.
GEMM extends this with ``persistent``; grouped conv extends with
``double_smem_buffer`` and ``num_groups_to_merge``.
"""
pipeline: str # mem, compv3, compv4, compv5, ...
epilogue: str # cshuffle, default
scheduler: str # intrawave, interwave
pad_m: bool
pad_n: bool
pad_k: bool
# Unsupported (pipeline, epilogue, scheduler) combinations.
# Only 'mem' and 'basic_v1' pipelines support interwave; all compute
# pipelines (compv3/v4/v5/v6/async) only support intrawave.
_UNSUPPORTED: ClassVar[FrozenSet] = frozenset(
{
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
("compv5", "cshuffle", "interwave"),
("compv5", "default", "interwave"),
("compv6", "cshuffle", "interwave"),
("compv6", "default", "interwave"),
("comp_async", "cshuffle", "interwave"),
("comp_async", "default", "interwave"),
("basic_async_v1", "cshuffle", "interwave"),
("basic_async_v1", "default", "interwave"),
}
)
def is_valid(self) -> bool:
return (self.pipeline, self.epilogue, self.scheduler) not in self._UNSUPPORTED
# ============================================================================
# Type Mappings (centralized for both GEMM and conv codegen)
# ============================================================================
class CommonTypeMappings:
"""Centralized type mappings shared by GEMM and grouped conv codegen."""
DTYPE_TO_CK = {
"fp16": "fp16_t",
"bf16": "bf16_t",
"fp32": "float",
"fp8": "fp8_t",
"bf8": "bf8_t",
"int8": "int8_t",
}
DTYPE_TO_CK_QUALIFIED = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int8": "int8_t",
}
DTYPE_TO_DISPATCHER = {
"fp16": "DataType::FP16",
"bf16": "DataType::BF16",
"fp32": "DataType::FP32",
"fp8": "DataType::FP8",
"bf8": "DataType::BF8",
"int8": "DataType::INT8",
}
# GEMM-specific layout mappings ("r"/"c" for row/column major).
# Convolution layouts (NHWGC, GKYXC, etc.) are handled by
# unified_grouped_conv_codegen.py via GroupedConvLayout / GroupedConvTypeMappings.
GEMM_LAYOUT_TO_CK = {
"r": "tensor_layout::gemm::RowMajor",
"c": "tensor_layout::gemm::ColumnMajor",
}
LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK # backward compat alias
GEMM_LAYOUT_TO_DISPATCHER = {
"r": "LayoutTag::RowMajor",
"c": "LayoutTag::ColMajor",
}
LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER # backward compat alias
# GEMM-only pipeline mappings (used by unified_gemm_codegen.py).
# Convolution pipelines are in GroupedConvTypeMappings
# (unified_grouped_conv_codegen.py). CK Tile conv supports:
# BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4.
# The dispatcher currently generates: mem, compv3, compv4.
# preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv).
PIPELINE_TO_CK = {
"mem": "GemmPipelineAgBgCrMem",
"compv3": "GemmPipelineAgBgCrCompV3",
"compv4": "GemmPipelineAgBgCrCompV4",
"compv5": "GemmPipelineAgBgCrCompV5",
"preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
}
PIPELINE_TO_BASE = {
"mem": "BaseGemmPipelineAgBgCrMem",
"compv3": "BaseGemmPipelineAgBgCrCompV3",
"compv4": "BaseGemmPipelineAgBgCrCompV4",
"compv5": "BaseGemmPipelineAgBgCrCompV5",
"preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
}
PIPELINE_TO_DISPATCHER = {
"mem": "Pipeline::Mem",
"compv3": "Pipeline::CompV3",
"compv4": "Pipeline::CompV4",
"compv5": "Pipeline::CompV5",
"preshufflev2": "Pipeline::PreShuffleV2",
}
SCHEDULER_TO_CK = {
"intrawave": "GemmPipelineScheduler::Intrawave",
"interwave": "GemmPipelineScheduler::Interwave",
"default": "GemmPipelineScheduler::Default",
}
SCHEDULER_TO_DISPATCHER = {
"intrawave": "Scheduler::Intrawave",
"interwave": "Scheduler::Interwave",
"default": "Scheduler::Auto",
}
EPILOGUE_TO_DISPATCHER = {
"cshuffle": "Epilogue::CShuffle",
"default": "Epilogue::Default",
}
@staticmethod
def get_output_dtype(dtype: str) -> str:
"""Get output datatype (fp8/bf8 -> fp16)."""
return "fp16" if dtype in ("fp8", "bf8") else dtype
# ============================================================================
# Code Generation Helpers
# ============================================================================
def generate_cpp_compilation_unit(kernel_name: str) -> str:
"""Generate a .cpp compilation unit that includes a kernel header.
This is the standard pattern: one .cpp per kernel that just includes
the generated .hpp header, causing template instantiation.
"""
return (
f"// Auto-generated compilation unit for {kernel_name}\n"
f'#include "{kernel_name}.hpp"\n'
)
def parallel_generate(
generate_fn: Callable[[T], R],
items: Sequence[T],
parallel: bool = True,
) -> List[R]:
"""Run ``generate_fn`` over ``items``, optionally in parallel.
Logs per-item progress (pattern adopted from the conv generator).
Returns results in completion order when parallel, otherwise in input order.
"""
results: List[R] = []
if not items:
return results
if parallel and len(items) > 1:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {executor.submit(generate_fn, item): item for item in items}
for future in concurrent.futures.as_completed(futures):
result = future.result()
results.append(result)
log.info("Generated: %s", futures[future])
else:
for item in items:
result = generate_fn(item)
results.append(result)
log.info("Generated: %s", item)
return results
# ============================================================================
# Arch-Aware Expansion Helpers (adopted from conv kernel_decl.hpp)
# ============================================================================
# These load from arch_specs_generated when available, falling back to
# hardcoded defaults that match the most common arch (gfx942).
_arch_data_cache: Optional[Dict] = None
def _get_arch_data() -> Dict:
"""Load arch filter data, with caching."""
global _arch_data_cache
if _arch_data_cache is not None:
return _arch_data_cache
try:
from arch_specs_generated import (
WARP_SUPPORTED_COMBINATIONS,
WARP_TILE_SUPPORTED_COMBINATIONS,
TRAIT_UNSUPPORTED_COMBINATIONS,
get_supported_archs,
)
_arch_data_cache = {
"warp_combos": WARP_SUPPORTED_COMBINATIONS,
"warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS,
"trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS,
"supported_archs": get_supported_archs(),
}
except ImportError:
_arch_data_cache = {
"warp_combos": {
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
},
"warp_tile_combos": {
"gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
"gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
},
"trait_unsupported": {
("compv3", "cshuffle", "interwave"),
("compv4", "cshuffle", "interwave"),
},
"supported_archs": ["gfx90a", "gfx942", "gfx950"],
}
return _arch_data_cache
def valid_wave_configs(arch: str) -> List[List[int]]:
"""Return valid [wave_m, wave_n, wave_k] combos for *arch*."""
data = _get_arch_data()
return data["warp_combos"].get(arch, [[2, 2, 1]])
def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]:
"""Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*.
The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is
fp32 for float types and int32 for int8.
"""
data = _get_arch_data()
acc = "int32" if dtype == "int8" else "fp32"
dtype_key = f"{dtype}_{dtype}_{acc}"
arch_tiles = data["warp_tile_combos"].get(arch, {})
return arch_tiles.get(dtype_key, [[32, 32, 16]])
def valid_trait_configs() -> List[Tuple[str, str]]:
"""Return valid (pipeline, scheduler) pairs.
Compute pipelines only support intrawave; mem supports both.
"""
return [
("compv3", "intrawave"),
("compv4", "intrawave"),
("compv5", "intrawave"),
("mem", "intrawave"),
("mem", "interwave"),
]
def needs_wave_expansion(config: dict) -> bool:
"""True if wave_m or wave_n is a wildcard (ANY_INT = -1)."""
return config.get("wave_m", 2) == ANY_INT or config.get("wave_n", 2) == ANY_INT
def needs_warp_expansion(config: dict) -> bool:
"""True if warp_m or warp_n is a wildcard (ANY_INT = -1)."""
return config.get("warp_m", 32) == ANY_INT or config.get("warp_n", 32) == ANY_INT
def needs_pipeline_expansion(config: dict) -> bool:
"""True if pipeline is a wildcard (\"*\")."""
return config.get("pipeline", "compv4") == "*"


@@ -109,7 +109,7 @@ inline void register_all_kernels()
"""
output_file.write_text(content)
print(f" Generated registration header: {output_file}")
print(f"OK Generated registration header: {output_file}")
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
@@ -143,7 +143,7 @@ namespace generated {
"""
output_file.write_text(content)
print(f" Generated registration implementation: {output_file}")
print(f"OK Generated registration implementation: {output_file}")
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
@@ -414,8 +414,8 @@ def main():
with open(manifest_output, "w") as f:
json.dump(manifest_data, f, indent=2)
print(f" Generated manifest: {manifest_output}")
print("\n Registration code generation complete!")
print(f"OK Generated manifest: {manifest_output}")
print("\nOK Registration code generation complete!")
print(f" Total kernels: {len(kernels)}")
print(" Output files:")
print(f" - {registration_header}")


@@ -17,10 +17,10 @@ Usage:
Output structure:
build/kernel_wrappers/
├── gemm_fp16_rcr_128x128x32.cpp
├── gemm_fp16_rcr_256x256x64.cpp
├── conv_fwd_fp16_2d_128x128.cpp
└── ...
|---- gemm_fp16_rcr_128x128x32.cpp
|---- gemm_fp16_rcr_256x256x64.cpp
|---- conv_fwd_fp16_2d_128x128.cpp
+---- ...
Each .cpp simply includes its corresponding .hpp and forces symbol emission.
"""


@@ -359,8 +359,8 @@ class ConvTraitConfig:
@dataclass
class ConvKernelConfig:
"""Complete convolution kernel configuration"""
class GroupedConvKernelConfig:
"""Complete grouped convolution kernel configuration"""
tile: ConvTileConfig = field(default_factory=ConvTileConfig)
trait: ConvTraitConfig = field(default_factory=ConvTraitConfig)
@@ -419,7 +419,11 @@ class ConvKernelConfig:
def kernel_name(self) -> str:
"""Generate kernel name from config"""
variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"}
variant_map = {
"forward": "fwd",
"bwd_data": "bwd_data",
"bwd_weight": "bwd_weight",
}
var_str = variant_map.get(self.variant, self.variant)
name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d"
@@ -433,11 +437,11 @@ class ConvKernelConfig:
@dataclass
class ConvKernelConfigSet:
class GroupedConvKernelConfigSet:
"""A set of convolution kernel configurations loaded from JSON"""
name: str = "default"
configs: List[ConvKernelConfig] = field(default_factory=list)
configs: List[GroupedConvKernelConfig] = field(default_factory=list)
# Tile parameter ranges
tile_m_values: List[int] = field(default_factory=lambda: [128])
@@ -481,7 +485,7 @@ class ConvKernelConfigSet:
layout: str = "nhwgc"
gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])
def generate_configs(self) -> Iterator[ConvKernelConfig]:
def generate_configs(self) -> Iterator[GroupedConvKernelConfig]:
"""Generate all kernel configurations (cartesian product)"""
# Tile parameters
tile_params = itertools.product(
@@ -548,7 +552,7 @@ class ConvKernelConfigSet:
double_smem_buffer=trait[6],
num_groups_to_merge=trait[7],
)
yield ConvKernelConfig(
yield GroupedConvKernelConfig(
tile=tile_cfg,
trait=trait_cfg,
dtype_input=self.dtype_input,
@@ -599,7 +603,9 @@ class ConvKernelConfigSet:
return tile_count * trait_count * extra_count * len(self.gpu_targets)
def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
def load_grouped_conv_kernel_configs(
json_path: str | Path,
) -> GroupedConvKernelConfigSet:
"""
Load convolution kernel configurations from a JSON file.
@@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
json_path: Path to JSON configuration file
Returns:
ConvKernelConfigSet with all parameter values loaded
GroupedConvKernelConfigSet with all parameter values loaded
"""
json_path = Path(json_path)
with open(json_path) as f:
data = json.load(f)
config_set = ConvKernelConfigSet()
config_set = GroupedConvKernelConfigSet()
# Name
config_set.name = data.get("kernel_set_name", json_path.stem)
@@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
def generate_cpp_conv_kernel_set_declaration(
config_set: ConvKernelConfigSet,
config_set: GroupedConvKernelConfigSet,
set_name: Optional[str] = None,
) -> str:
"""
Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet.
Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet.
"""
name = set_name or config_set.name
lines = [f"DECL_CONV_KERNEL_SET({name},"]
lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"]
for config in config_set.generate_configs():
line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, '


@@ -7,7 +7,7 @@
Unified GEMM Code Generator - Single Source of Truth
This is THE unified code generator for all GEMM kernel variants:
- Standard GEMM (C = A × B)
- Standard GEMM (C = A x B)
- Preshuffle GEMM (optimized weight access)
- Multi-D GEMM (element-wise fusion)
@@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict
from enum import Enum
import concurrent.futures
from codegen_common import (
TileConfig,
TraitConfigBase,
CommonTypeMappings as TypeMappings,
)
# Import architecture filter for GPU-specific validation
try:
from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType
@@ -194,62 +200,14 @@ class GemmVariant(Enum):
MULTI_D = "multi_d"
@dataclass
class TileConfig:
"""Tile configuration parameters"""
tile_m: int
tile_n: int
tile_k: int
warp_m: int
warp_n: int
warp_k: int
warp_tile_m: int
warp_tile_n: int
warp_tile_k: int
def is_valid(self) -> bool:
"""Validate tile configuration"""
return (
self.tile_m % (self.warp_m * self.warp_tile_m) == 0
and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
and self.tile_m > 0
and self.tile_n > 0
and self.tile_k > 0
)
# TileConfig imported from codegen_common
@dataclass
class TraitConfig:
"""Kernel trait configuration"""
class TraitConfig(TraitConfigBase):
"""GEMM-specific trait configuration extending TraitConfigBase with persistent mode."""
pipeline: str # mem, compv3, compv4
epilogue: str # default, cshuffle
scheduler: str # intrawave, interwave
pad_m: bool
pad_n: bool
pad_k: bool
persistent: bool
def is_valid(self) -> bool:
"""Check if trait combination is valid"""
# Unsupported combinations
# Only 'mem' pipeline supports interwave scheduler.
# All compute pipelines (compv3/v4/v5/v6/async) only support intrawave.
unsupported = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
("compv5", "cshuffle", "interwave"),
("compv5", "default", "interwave"),
("compv6", "cshuffle", "interwave"),
("compv6", "default", "interwave"),
("comp_async", "cshuffle", "interwave"),
("comp_async", "default", "interwave"),
}
return (self.pipeline, self.epilogue, self.scheduler) not in unsupported
persistent: bool = False
@dataclass
@@ -345,89 +303,7 @@ class KernelConfig:
# ============================================================================
class TypeMappings:
"""Centralized type mappings for code generation"""
DTYPE_TO_CK = {
"fp16": "fp16_t",
"bf16": "bf16_t",
"fp32": "float",
"fp8": "fp8_t",
"bf8": "bf8_t",
"int8": "int8_t",
}
# Fully-qualified types for use outside of 'using namespace ck_tile' scope
DTYPE_TO_CK_QUALIFIED = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float", # Built-in type, no namespace
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int8": "int8_t", # Built-in type
}
DTYPE_TO_DISPATCHER = {
"fp16": "DataType::FP16",
"bf16": "DataType::BF16",
"fp32": "DataType::FP32",
"fp8": "DataType::FP8",
"bf8": "DataType::BF8",
"int8": "DataType::INT8",
}
LAYOUT_TO_CK = {
"r": "tensor_layout::gemm::RowMajor",
"c": "tensor_layout::gemm::ColumnMajor",
}
LAYOUT_TO_DISPATCHER = {
"r": "LayoutTag::RowMajor",
"c": "LayoutTag::ColMajor",
}
PIPELINE_TO_CK = {
"mem": "GemmPipelineAgBgCrMem",
"compv3": "GemmPipelineAgBgCrCompV3",
"compv4": "GemmPipelineAgBgCrCompV4",
"preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
}
PIPELINE_TO_BASE = {
"mem": "BaseGemmPipelineAgBgCrMem",
"compv3": "BaseGemmPipelineAgBgCrCompV3",
"compv4": "BaseGemmPipelineAgBgCrCompV4",
"preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
}
PIPELINE_TO_DISPATCHER = {
"mem": "Pipeline::Mem",
"compv3": "Pipeline::CompV3",
"compv4": "Pipeline::CompV4",
"preshufflev2": "Pipeline::PreShuffleV2",
}
SCHEDULER_TO_CK = {
"intrawave": "GemmPipelineScheduler::Intrawave",
"interwave": "GemmPipelineScheduler::Interwave",
"default": "GemmPipelineScheduler::Default",
}
SCHEDULER_TO_DISPATCHER = {
"intrawave": "Scheduler::Intrawave",
"interwave": "Scheduler::Interwave",
"default": "Scheduler::Auto",
}
EPILOGUE_TO_DISPATCHER = {
"cshuffle": "Epilogue::CShuffle",
"default": "Epilogue::Default",
}
@staticmethod
def get_output_dtype(dtype: str) -> str:
"""Get output datatype (fp8/bf8 -> fp16)"""
return "fp16" if dtype in ["fp8", "bf8"] else dtype
# TypeMappings is imported from codegen_common (CommonTypeMappings, aliased as TypeMappings)
# ============================================================================
@@ -1068,7 +944,11 @@ class UnifiedGemmCodegen:
}
def generate_all(self, parallel: bool = True) -> Dict:
"""Generate all kernels"""
"""Generate all kernels.
When parallel=True, all configs across all variants are collected first,
then generated concurrently in a single thread pool for maximum throughput.
"""
log.info("Generating GEMM kernels:")
log.info(f" Datatype: {self.datatype}")
log.info(f" Layout: {self.layout}")
@@ -1078,49 +958,24 @@ class UnifiedGemmCodegen:
results = {"kernels": [], "wrappers": [], "failed": []}
# Get configurations
# Collect ALL configs across all variants/preselected sets upfront
all_configs = []
if self.use_preselected:
configs = self._get_preselected_configs()
log.info(f" Total configurations: {len(configs)}")
all_configs = self._get_preselected_configs()
log.info(f" Total configurations: {len(all_configs)}")
else:
for variant in self.variants:
log.info(f"\nGenerating {variant.value} kernels...")
configs = self._get_configs_for_variant(variant)
log.info(f" Configurations: {len(configs)}")
log.info(f" {variant.value}: {len(configs)} configurations")
all_configs.extend(configs)
log.info(f" Total across all variants: {len(all_configs)}")
if parallel:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self._generate_one, cfg) for cfg in configs
]
for future in concurrent.futures.as_completed(futures):
try:
k, w = future.result()
results["kernels"].append(k)
results["wrappers"].append(w)
except Exception as e:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
else:
for cfg in configs:
try:
k, w = self._generate_one(cfg)
results["kernels"].append(k)
results["wrappers"].append(w)
except Exception as e:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
# Generate registration header
if results["wrappers"]:
self._generate_registration_header(results["wrappers"])
return results
# Generate from preselected set
if parallel:
# Generate all configs in a single parallel pass
if parallel and all_configs:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(self._generate_one, cfg) for cfg in configs]
futures = [
executor.submit(self._generate_one, cfg) for cfg in all_configs
]
for future in concurrent.futures.as_completed(futures):
try:
k, w = future.result()
@@ -1130,7 +985,7 @@ class UnifiedGemmCodegen:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
else:
for cfg in configs:
for cfg in all_configs:
try:
k, w = self._generate_one(cfg)
results["kernels"].append(k)
@@ -1139,7 +994,6 @@ class UnifiedGemmCodegen:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
# Generate registration header
if results["wrappers"]:
self._generate_registration_header(results["wrappers"])
@@ -1638,12 +1492,19 @@ def main():
# Write to temp file and use as config
import tempfile
import os as _os
with tempfile.NamedTemporaryFile(
_tmp_config = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json.dump(full_config, f)
args.config = Path(f.name)
)
try:
json.dump(full_config, _tmp_config)
_tmp_config.close()
args.config = Path(_tmp_config.name)
except Exception:
_tmp_config.close()
_os.unlink(_tmp_config.name)
raise
except json.JSONDecodeError as e:
logging.error(f"Invalid tile-config-json: {e}")
return 1
@@ -1672,7 +1533,7 @@ def main():
results = codegen.generate_all(parallel=not args.no_parallel)
logging.info("\nGeneration complete!")
logging.info("\nGeneration complete.")
logging.info(f" Kernels: {len(results['kernels'])}")
logging.info(f" Wrappers: {len(results['wrappers'])}")
logging.info(f" Failed: {len(results['failed'])}")
@@ -1684,7 +1545,7 @@ def main():
# Generate dispatcher registration if requested
if args.register:
logging.info("\n📝 Generating dispatcher registration code...")
logging.info("\nGenerating dispatcher registration code...")
try:
from generate_dispatcher_registration import (
scan_generated_headers,
@@ -1701,11 +1562,20 @@ def main():
)
generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp")
logging.info(f"Generated registration code for {len(kernels)} kernels")
logging.info(f"Generated registration code for {len(kernels)} kernels")
except Exception as e:
logging.error(f"Failed to generate registration code: {e}")
return 1
# Clean up temp config file if we created one
if args.tile_config_json and args.config and args.config.exists():
try:
import os as _os
_os.unlink(args.config)
except OSError:
pass
return 0 if not results["failed"] else 1
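The temp-config change above replaces the `with NamedTemporaryFile` block with an explicit close plus a later manual `unlink`: with `delete=False`, the file must be closed before other code reopens it by path, and it must be removed by hand. A minimal standalone sketch of that pattern (the `write_temp_config` helper name is hypothetical, not from the PR):

```python
import json
import os
import tempfile

def write_temp_config(payload: dict) -> str:
    """Write payload to a temp JSON file and return its path (caller unlinks)."""
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
    try:
        json.dump(payload, tmp)
        tmp.close()          # close so a later step can reopen it by path
        return tmp.name
    except Exception:
        tmp.close()
        os.unlink(tmp.name)  # don't leak the file if writing failed
        raise

path = write_temp_config({"tile_m": [128], "tile_n": [128]})
with open(path) as f:
    assert json.load(f) == {"tile_m": [128], "tile_n": [128]}
os.unlink(path)              # explicit cleanup, mirroring the diff
```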

File diff suppressed because it is too large.


@@ -187,7 +187,6 @@ function(add_gpu_example NAME SOURCE KERNEL_HEADER)
if(HEADER_NAME STREQUAL "register_all_kernels.hpp")
# Registration header - examples include it directly
target_compile_options(${NAME} PRIVATE
-DGEMM_KERNEL_AVAILABLE=1
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
@@ -315,6 +314,7 @@ function(add_declarative_gpu_example NAME SOURCE)
target_include_directories(${NAME} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${EXAMPLE_KERNEL_DIR}
${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers
)
@@ -322,7 +322,6 @@ function(add_declarative_gpu_example NAME SOURCE)
# Force-include the generated registration header
target_compile_options(${NAME} PRIVATE
-include ${EXAMPLE_HEADER}
-DGEMM_KERNEL_AVAILABLE=1
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
@@ -345,6 +344,7 @@ add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_v
add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp)
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)
add_declarative_gpu_example(gemm_07_gfx950_minimal gemm/cpp/07_gfx950_minimal.cpp)
# ML Heuristic example -- requires LightGBM shared library
# Derive site-packages from active Python interpreter (respects virtualenvs)
@@ -443,19 +443,79 @@ if(hip_FOUND)
endif()
add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel)
# =============================================================================
# Grouped Convolution C++ Examples
# =============================================================================
add_declarative_gpu_example(grouped_conv_01_basic grouped_conv/cpp/01_basic_grouped_conv.cpp)
add_declarative_gpu_example(grouped_conv_02_all_dirs grouped_conv/cpp/02_all_directions.cpp)
add_declarative_gpu_example(grouped_conv_03_bench_val grouped_conv/cpp/03_benchmark_validation.cpp)
add_declarative_gpu_example(grouped_conv_04_registry_json grouped_conv/cpp/04_registry_json.cpp)
add_declarative_gpu_example(grouped_conv_05_bwd_data grouped_conv/cpp/05_bwd_data.cpp)
add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bwd_weight.cpp)
add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp)
# =============================================================================
# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D)
# =============================================================================
# Kernel output directory for the Python conv library
set(CONV_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/conv_python_fallback")
set(CONV_DISPATCH_HEADER "${CONV_FALLBACK_KERNEL_DIR}/conv_python_dispatch.hpp")
# Generate ALL conv kernels (fwd/bwd_data/bwd_weight x 2D/3D x multiple tile configs)
# then create the dispatch header with 2D/3D aliases
add_custom_command(
OUTPUT ${CONV_DISPATCH_HEADER}
COMMAND ${CMAKE_COMMAND} -E make_directory ${CONV_FALLBACK_KERNEL_DIR}
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_grouped_conv_codegen.py
--variant forward bwd_data bwd_weight --ndim 2 3
--datatype fp16 --arch ${GPU_TARGET}
--output ${CONV_FALLBACK_KERNEL_DIR}
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/generate_conv_dispatch_header.py
--kernel-dir ${CONV_FALLBACK_KERNEL_DIR}
--output ${CONV_DISPATCH_HEADER}
COMMENT "Generating conv kernels (fwd/bwd_data/bwd_weight x 2D/3D) for Python library..."
VERBATIM
)
add_custom_target(generate_conv_fallback_kernels DEPENDS ${CONV_DISPATCH_HEADER})
# Conv dynamic library for Python (all 6 kernel variants)
add_library(dispatcher_conv_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/conv_ctypes_lib.cpp)
target_link_libraries(dispatcher_conv_lib PRIVATE ck_tile_dispatcher)
target_include_directories(dispatcher_conv_lib PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../include
${CONV_FALLBACK_KERNEL_DIR}
)
target_compile_options(dispatcher_conv_lib PRIVATE
-include ${CONV_DISPATCH_HEADER}
-DGFX_ARCH="${GPU_TARGET}"
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
if(hip_FOUND)
target_link_libraries(dispatcher_conv_lib PRIVATE hip::device hip::host)
endif()
add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels)
message(STATUS "GEMM examples configured - kernels will be generated during 'make'")
message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'")
# Convenience target to build all Python ctypes libraries
add_custom_target(python_libs
DEPENDS dispatcher_gemm_lib dispatcher_conv_lib
COMMENT "Building Python ctypes libraries (GEMM + Conv)"
)
# =============================================================================
# Per-Architecture Kernel Generation Targets
# =============================================================================
set(SUPPORTED_GPU_ARCHS gfx942 gfx950 gfx90a gfx1100 gfx1030)
foreach(ARCH ${SUPPORTED_GPU_ARCHS})
# GEMM kernels for this arch

View File

@@ -1,8 +1,6 @@
# CK Tile Dispatcher Examples
Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution.
---
@@ -60,11 +58,11 @@ python3 examples/gemm/python/08_heuristics.py
```
examples/
|---- gemm/
| |---- cpp/ # 6 C++ GEMM examples
| +---- python/ # 11 Python GEMM examples
|
+---- README.md
```
---
@@ -201,10 +199,31 @@ rocminfo | grep "Name:"
---
## Grouped Convolution
Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM.
### Infrastructure
The grouped convolution code generation, utilities, and build scripts are available:
| Component | Location |
|-----------|----------|
| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` |
| Python Codegen | `codegen/unified_grouped_conv_codegen.py` |
| Python Utils | `python/grouped_conv_utils.py` |
| Build Script | `scripts/compile_grouped_conv_examples.py` |
### Building Grouped Conv Kernels
```bash
# Generate grouped conv kernels
python3 codegen/unified_grouped_conv_codegen.py \
--output-dir build/generated_kernels \
--datatype fp16 --variant forward --ndim-spatial 2
# Compile a grouped conv example
python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp
```
See the [main README](../README.md#grouped-convolution-support) for more details.
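The conv examples size their output buffers and CPU references from the standard convolution output-size formula. A quick plain-Python sanity check (illustrative only, not part of the dispatcher API; the same per-dimension formula applies to both the 2D and 3D variants):

```python
def conv_out_size(in_size: int, kernel: int, stride: int = 1,
                  pad: int = 0, dilation: int = 1) -> int:
    """Output extent of a forward convolution along one spatial dimension."""
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# 56x56 input, 3x3 kernel, stride 1, pad 1 -> 56x56 output (shape-preserving)
assert conv_out_size(56, 3, stride=1, pad=1) == 56
# 224x224 input, 7x7 kernel, stride 2, pad 3 -> 112x112 output
assert conv_out_size(224, 7, stride=2, pad=3) == 112
```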

View File

@@ -21,9 +21,9 @@
* - pipeline: "compv3" -> 1 option (compv4 requires special handling)
* - scheduler: "intrawave" -> 1 option
*
* Raw expansion: 3 x 2 = 6 configs, but arch filter validates each:
* - tile_m must be divisible by (warp_m x warp_tile_m)
* - tile_n must be divisible by (warp_n x warp_tile_n)
* - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16)
* Result: 4 valid wildcard kernels + 1 explicit = 5 total
*
@@ -70,13 +70,13 @@ DECL_KERNEL_SET(multi_size_kernels,
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 64)
.wave(ANY_INT, ANY_INT, 1) // ANY_INT -> (1,4,1), (2,2,1), (4,1,1)
.warp(-1, -1, -1) // -1 same as ANY_INT -> (16,16,32), (32,32,16)
.pipeline("*") // "*" -> valid pipelines
.scheduler("*") // "*" -> valid schedulers
.epilogue("cshuffle"),
"gfx942"));
// Raw: 3x2=6, arch filter removes 2 invalid -> 4 valid kernels
// =============================================================================
// MAIN
@@ -116,8 +116,8 @@ int main(int argc, char* argv[])
.pipeline("*") -> expands to valid pipelines = 1
.scheduler("*") -> expands to valid schedulers = 1
Expanded: 3 x 2 = 6 configs, but arch filter validates each:
- wave x warp must divide tile: (4,1,1)x(32,32,16) invalid for 64x64
- Result: 4 valid kernels from wildcard + 1 explicit = 5 total
)";

View File

@@ -0,0 +1,191 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 07: Minimal gfx950 (CDNA4 / MI350) GEMM
*
* Demonstrates the dispatcher working with gfx950-specific kernels:
*
* - fp16 GEMM with standard tile configs
* - fp8 GEMM with gfx950-extended warp tiles (16x16x128)
* - 160KB LDS: gfx950 raises the LDS capacity from 64KB to 160KB
*
* Build: cd dispatcher/build && cmake .. -DGPU_TARGETS=gfx950 && make gemm_07_gfx950_minimal
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::backends;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// gfx950-targeted kernel declarations
// =============================================================================
DECL_KERNEL_SET(gfx950_gemm_kernels,
// fp16 128x128x32 -- bread-and-butter config, works on all CDNA
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 32)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx950")
// fp16 128x128x64 -- deeper K tile using more LDS
// LDS usage: 128*64*2 + 128*64*2 = 32768 bytes (fits in 64KB; gfx950 has 160KB)
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(128, 128, 64)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx950")
// fp16 64x64x32 -- small-tile variant for small problems
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 32)
.wave(2, 2, 1)
.warp(16, 16, 32)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle"),
"gfx950"));
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
ExampleArgs args("Example 07: gfx950 Minimal GEMM",
"Demonstrates gfx950 (CDNA4 / MI350) dispatcher");
args.add_flag("--list", "List registered kernels");
args.add_flag("--list-verbose", "List registered kernels with full details");
args.add_option("--M", "1024", "Problem M dimension");
args.add_option("--N", "1024", "Problem N dimension");
args.add_option("--K", "1024", "Problem K dimension");
args.add_option("--arch", "gfx950", "GPU architecture (default: gfx950)");
if(!args.parse(argc, argv))
return 0;
std::string gfx_arch = args.get("--arch", "gfx950");
print_header("Example 07: gfx950 (CDNA4) Minimal GEMM");
// =========================================================================
// Architecture info
// =========================================================================
std::cout << "\ngfx950 (CDNA4 / MI350) highlights:\n";
std::cout << " - 160KB LDS (up from 64KB on gfx942)\n";
std::cout << " - Extended FP8 warp tiles: 16x16x128, 32x32x64\n";
std::cout << " - Packed FP4 support (pk_fp4)\n";
std::cout << " - Same warp configs as gfx942: [1,4,1], [2,2,1], [4,1,1]\n\n";
// =========================================================================
// Register kernels
// =========================================================================
std::cout << "Registering kernels for " << gfx_arch << "...\n";
Registry registry;
registry.set_name("gfx950_gemm");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
if(args.has("--list") || args.has("--list-verbose"))
{
std::cout << "\n";
print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
return 0;
}
if(registry.size() == 0)
{
std::cerr << "ERROR: No kernels registered for " << gfx_arch << "!\n";
std::cerr << " Did you build with -DGPU_TARGETS=gfx950?\n";
return 1;
}
// =========================================================================
// Create Dispatcher
// =========================================================================
Dispatcher dispatcher(&registry);
// =========================================================================
// Setup Problem
// =========================================================================
const int M = args.get_int("--M", 1024);
const int N = args.get_int("--N", 1024);
const int K = args.get_int("--K", 1024);
std::cout << "\nProblem: " << M << " x " << N << " x " << K << "\n";
Problem problem(M, N, K);
using DataType = ck_tile::fp16_t;
GpuBuffer<DataType> a_dev(M * K);
GpuBuffer<DataType> b_dev(K * N);
GpuBuffer<DataType> c_dev(M * N);
std::vector<DataType> a_host(M * K, DataType(1.0f));
std::vector<DataType> b_host(K * N, DataType(1.0f));
a_dev.copy_from_host(a_host.data());
b_dev.copy_from_host(b_host.data());
c_dev.zero();
// =========================================================================
// Select and Run
// =========================================================================
auto selected = dispatcher.select_kernel(problem);
if(!selected)
{
std::cerr << "ERROR: No suitable kernel found for " << M << "x" << N << "x" << K << "\n";
return 1;
}
std::cout << " Selected: " << selected->get_name() << "\n";
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n";
// =========================================================================
// Verify
// =========================================================================
std::cout << "\nVerification:\n";
std::vector<DataType> c_host(M * N);
c_dev.copy_to_host(c_host.data());
const float expected = static_cast<float>(K);
int errors = 0;
for(int i = 0; i < std::min(M * N, 1024); ++i)
{
if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
++errors;
}
bool passed = (errors == 0);
std::cout << " Expected value: " << expected << "\n";
std::cout << " Errors (first 1024 elements): " << errors << "\n";
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}

View File

@@ -29,14 +29,14 @@ cd examples
## Examples
| Example | Description |
|---------|-------------|
| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect |
| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations |
| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation |
| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection |
| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools |
| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets |
## Example Details
@@ -225,5 +225,5 @@ DECL_KERNEL_SET(my_kernels,
## Related Documentation
- [Python GEMM Examples](../python/README.md)
- [C++ Headers (GEMM + Grouped Conv)](../../../include/ck_tile/dispatcher/README.md)
- [Main Dispatcher README](../../../README.md)
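The benchmark examples above report throughput in TFLOPS using the conventional 2·M·N·K FLOP count for GEMM. A minimal sketch of that arithmetic (presumably the same quantity as the `calculate_tflops` helper used in the C++ examples):

```python
def gemm_tflops(M: int, N: int, K: int, time_ms: float) -> float:
    """GEMM throughput, assuming the conventional 2*M*N*K FLOP count."""
    flops = 2.0 * M * N * K
    return flops / (time_ms * 1e-3) / 1e12

# A 1024^3 GEMM finishing in 0.05 ms sustains ~42.95 TFLOPS
print(f"{gemm_tflops(1024, 1024, 1024, 0.05):.2f} TFLOPS")
```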

View File

@@ -7,41 +7,37 @@
Example 01: Basic GEMM with Multiple Kernels
Demonstrates:
1. Building a Registry with multiple kernel configurations
2. Parallel JIT compilation via registry.build()
3. Running each kernel and validating output against NumPy reference
4. Comparing performance across kernels
Complexity: ★★☆☆☆
Usage:
python3 01_basic_gemm.py
python3 01_basic_gemm.py --help
python3 01_basic_gemm.py --dtype bf16
python3 01_basic_gemm.py --size 2048
python3 01_basic_gemm.py --num-kernels 4
python3 01_basic_gemm.py --workers 4
"""
import sys
import time
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Registry,
detect_gpu_arch,
)
@dataclass
class KernelSpec:
"""Specification for a kernel configuration"""
name: str
tile_m: int
tile_n: int
@@ -50,80 +46,37 @@ class KernelSpec:
scheduler: str = "intrawave"
# Define multiple kernel configurations to test (50+ kernels)
KERNEL_SPECS = [
# Small tiles - compv3
# Small tiles
KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
# Small tiles - compv4
KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
# Medium tiles - compv3
# Medium tiles
KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
# Medium tiles - compv4
KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
# Rectangular tiles - compv3
# Rectangular tiles
KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
# Rectangular tiles - compv4
KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
# Large tiles - compv3
# Large tiles
KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
# Large tiles - compv4
KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
# Interwave scheduler variants
KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
# Interwave scheduler
KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
# More tile_k variations - compv3
KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
# More tile_k variations - compv4
KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
# Additional rectangular
KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
# Additional compv4 variants
KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
"""Create a KernelConfig from a spec"""
# Adjust warp tiles based on tile size
if spec.tile_m <= 64:
warp_m, warp_n = 16, 16
else:
warp_m, warp_n = 32, 32
def spec_to_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
warp_m, warp_n = (16, 16) if spec.tile_m <= 64 else (32, 32)
return KernelConfig(
dtype_a=dtype,
dtype_b=dtype,
@@ -148,180 +101,118 @@ def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfi
)
def print_kernel_table(specs: List[KernelSpec], dtype: str):
"""Print a formatted table of kernel configurations"""
print("\n" + "=" * 70)
print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
print("=" * 70)
print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
print(" " + "-" * 68)
for i, spec in enumerate(specs, 1):
tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
print(
f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}"
)
print(" " + "-" * 68)
print(f" Data type: {dtype}")
def main():
parser = argparse.ArgumentParser(
description="Basic GEMM Example with Multiple Kernels",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python3 01_basic_gemm.py # Default FP16 with 4 kernels
python3 01_basic_gemm.py --dtype bf16 # BF16 mode
python3 01_basic_gemm.py --size 2048 # Larger problem size
python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels
""",
)
parser = argparse.ArgumentParser(description="Basic GEMM with Multiple Kernels")
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--size", type=int, default=512, help="Problem size MxNxK")
parser.add_argument("--num-kernels", type=int, default=0, help="0 = all")
parser.add_argument(
"--dtype",
default="fp16",
choices=["fp16", "bf16", "fp32"],
help="Data type (default: fp16)",
)
parser.add_argument(
"--arch",
default="gfx942",
help="Target architecture (default: gfx942)",
)
parser.add_argument(
"--size",
type=int,
default=512,
help="Problem size MxNxK (default: 512)",
)
parser.add_argument(
"--num-kernels",
type=int,
default=0,
help="Number of kernels to test (0 = all)",
"--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)"
)
args = parser.parse_args()
reset_for_example()
print("=" * 70)
print("Example 01: Basic GEMM with Multiple Kernels")
print("=" * 70)
# Select kernels to test
specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
# =========================================================================
# Step 1: Print all kernel configurations
# =========================================================================
print_kernel_table(specs, args.dtype)
# =========================================================================
# Step 2: Setup and test each kernel
# =========================================================================
print("\n" + "=" * 70)
print(" RUNNING KERNELS")
print("=" * 70)
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
M, N, K = args.size, args.size, args.size
results = []
print(f"\n Problem size: {M}x{N}x{K}\n")
# Step 1: Build registry
print(
f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}"
)
print(" " + "-" * 78)
for i, spec in enumerate(specs, 1):
# Create unique test data per kernel
np.random.seed(42 + i * 1000)
A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
# Create config and setup dispatcher
config = create_kernel_config(spec, args.dtype, args.arch)
setup = setup_gemm_dispatcher(
config=config,
registry_name=f"kernel_{spec.name}",
verbose=False,
auto_rebuild=True,
print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
print(" " + "-" * 64)
for i, s in enumerate(specs, 1):
print(
f" {i:<3} {s.name:<22} {s.tile_m}x{s.tile_n}x{s.tile_k:<6} {s.pipeline:<10} {s.scheduler:<12}"
)
reg = Registry(name="basic_gemm")
for s in specs:
reg.register_kernel(spec_to_config(s, args.dtype, args.arch))
# Step 2: Parallel JIT build via registry.build()
workers = args.workers if args.workers > 0 else None
print(
f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---"
)
t0 = time.perf_counter()
setups = reg.build(verbose=False, max_workers=workers)
jit_build_s = time.perf_counter() - t0
built = sum(1 for s in setups if s.success)
print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s")
if built == 0:
print(" ERROR: No kernels built")
return 1
# Step 3: Run each kernel and validate
print(f"\n--- Running Kernels (problem {args.size}x{args.size}x{args.size}) ---")
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
M = N = K = args.size
np.random.seed(42)
A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
print(
f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}"
)
print(" " + "-" * 80)
results = []
for i, (spec, setup) in enumerate(zip(specs, setups), 1):
tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
if not setup.success:
print(
f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
)
results.append((spec.name, False, 0, 0, 0))
cleanup_gemm()
results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
dispatcher = setup.dispatcher
# Check if size is supported
if not dispatcher.is_supported(M, N, K):
disp = setup.dispatcher
if not disp.is_supported(M, N, K):
print(
f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
)
results.append((spec.name, False, 0, 0, 0))
cleanup_gemm()
results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
# Run GEMM
result = dispatcher.run(A, B, M, N, K)
if not result.success:
res = disp.run(A, B, M, N, K)
if not res.success:
print(
f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}"
)
results.append((spec.name, False, 0, 0, 0))
cleanup_gemm()
results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
# Validate against NumPy reference
C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
max_err = np.max(np.abs(result.output - C_ref))
# Check if within tolerance
passed = max_err < 1e-2
status = "PASS" if passed else "FAIL"
max_err = float(np.max(np.abs(res.output - C_ref)))
ok = max_err < 1e-2
tag = "PASS" if ok else "FAIL"
print(
f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
f" {i:<3} {spec.name:<22} {tile:<14} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}"
)
results.append((spec.name, passed, result.time_ms, result.tflops, max_err))
cleanup_gemm()
# =========================================================================
# Step 3: Summary
# =========================================================================
print("\n" + "=" * 70)
print(" SUMMARY")
print("=" * 70)
results.append((spec.name, ok, res.time_ms, res.tflops, max_err))
# Step 4: Summary
passed = sum(1 for r in results if r[1])
failed = len(results) - passed
valid = [r for r in results if r[1]]
print(f"\n Results: {passed}/{len(results)} kernels passed")
print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")
if results:
valid_results = [r for r in results if r[1]]
if valid_results:
best = max(valid_results, key=lambda x: x[3])
print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")
if failed == 0:
print("\n *** ALL KERNELS PASSED ***")
else:
print(f"\n *** {failed} KERNELS FAILED ***")
print("\n" + "=" * 70)
print(f" Results: {passed}/{len(results)} passed")
print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")
print(f" JIT time: {jit_build_s:.1f} s (parallel)")
if valid:
best = max(valid, key=lambda x: x[3])
print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)")
print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
print("=" * 70)
return 0 if failed == 0 else 1

View File

@@ -6,9 +6,7 @@
"""
Example 02: Batch GEMM
Runs multiple GEMM operations with different sizes using JIT compilation.
Usage:
python3 02_batch_gemm.py
@@ -25,9 +23,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Registry,
detect_gpu_arch,
)
@@ -55,20 +52,20 @@ Examples:
help="Maximum problem size (default: 4096)",
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
reset_for_example()
print("=" * 60)
print("Example 02: Batch GEMM")
print("=" * 60)
# =========================================================================
# Step 1: Setup dispatcher
# Step 1: JIT build dispatcher
# =========================================================================
print("\nStep 1: Setup Dispatcher")
print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -80,19 +77,22 @@ Examples:
gfx_arch=args.arch,
)
setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True)
if not setup.success:
print(f" ERROR: {setup.error}")
reg = Registry(name="batch_gemm")
reg.register_kernel(config)
setups = reg.build(verbose=True)
if not setups or not setups[0].success:
error = setups[0].error if setups else "No kernels built"
print(f" ERROR: {error}")
return 1
dispatcher = setup.dispatcher
dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Run batch of different sizes
# =========================================================================
print("\nStep 2: Run Batch")
# Generate sizes up to max_size
all_sizes = [
(256, 256, 256),
(512, 512, 512),
@@ -135,9 +135,6 @@ Examples:
avg_tflops = (total_ops / 1e12) / (total_time / 1000)
print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")
# Cleanup
cleanup_gemm()
print("\n" + "=" * 60)
print("Batch GEMM complete!")
print("=" * 60)

View File

@@ -6,9 +6,8 @@
"""
Example 03: Benchmark
Performance benchmarking with compute-optimized kernel configuration
using JIT compilation.
Usage:
python3 03_benchmark.py
@@ -26,9 +25,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Registry,
detect_gpu_arch,
)
@@ -63,20 +61,20 @@ Examples:
"--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
reset_for_example()
print("=" * 60)
print("Example 03: Benchmark")
print("=" * 60)
# =========================================================================
# Step 1: Setup dispatcher with compute-optimized config
# Step 1: JIT build dispatcher with compute-optimized config
# =========================================================================
print("\nStep 1: Setup Dispatcher")
print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -90,12 +88,16 @@ Examples:
gfx_arch=args.arch,
)
setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
if not setup.success:
print(f" ERROR: {setup.error}")
reg = Registry(name="benchmark")
reg.register_kernel(config)
setups = reg.build(verbose=True)
if not setups or not setups[0].success:
error = setups[0].error if setups else "No kernels built"
print(f" ERROR: {error}")
return 1
dispatcher = setup.dispatcher
dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Benchmark
@@ -130,11 +132,9 @@ Examples:
A = np.random.randn(M, K).astype(np_dtype) * 0.1
B = np.random.randn(K, N).astype(np_dtype) * 0.1
# Warmup
for _ in range(args.warmup):
dispatcher.run(A, B, M, N, K)
# Benchmark
times = []
for _ in range(args.iterations):
result = dispatcher.run(A, B, M, N, K)
@@ -150,9 +150,6 @@ Examples:
f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
)
# Cleanup
cleanup_gemm()
# Summary
print("\n" + "=" * 60)
print("Summary")

View File

@@ -6,9 +6,7 @@
"""
Example 04: Validation
Validates GPU GEMM against NumPy reference using JIT compilation.
Usage:
python3 04_validation.py
@@ -26,9 +24,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
Validator,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Registry,
detect_gpu_arch,
)
@@ -56,20 +53,20 @@ Examples:
"--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
reset_for_example()
print("=" * 60)
print("Example 04: Validation")
print("=" * 60)
# =========================================================================
# Step 1: Setup dispatcher
# Step 1: JIT build dispatcher
# =========================================================================
print("\nStep 1: Setup Dispatcher")
print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -81,12 +78,16 @@ Examples:
gfx_arch=args.arch,
)
setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True)
if not setup.success:
print(f" ERROR: {setup.error}")
reg = Registry(name="validation")
reg.register_kernel(config)
setups = reg.build(verbose=True)
if not setups or not setups[0].success:
error = setups[0].error if setups else "No kernels built"
print(f" ERROR: {error}")
return 1
dispatcher = setup.dispatcher
dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Run validation tests
@@ -139,9 +140,6 @@ Examples:
print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED")
failed += 1
# Cleanup
cleanup_gemm()
# Summary
print("\n" + "=" * 60)
total = passed + failed

View File

@@ -8,7 +8,6 @@ Example 05: NumPy Integration
Shows how to create a GPU-accelerated matmul wrapper.
Complexity: ★★☆☆☆
Usage:
python3 05_numpy_integration.py
@@ -29,6 +28,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
@@ -70,7 +70,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()

View File

@@ -8,7 +8,6 @@ Example 06: JSON Export
Exports registry configuration to JSON.
Complexity: ★★☆☆☆
Usage:
python3 06_json_export.py
@@ -28,6 +27,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
@@ -54,7 +54,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()

View File

@@ -18,7 +18,6 @@ This tests:
- Multiple data types (fp16, bf16)
- Different schedulers (intrawave, interwave)
Complexity: ★★★★☆
Usage:
python3 07_stress_test.py
@@ -43,6 +42,7 @@ from ctypes_utils import (
cleanup_gemm,
reset_for_example,
Validator,
detect_gpu_arch,
)
@@ -413,8 +413,8 @@ Examples:
)
parser.add_argument(
"--arch",
default="gfx942",
help="Target architecture (default: gfx942)",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()

View File

@@ -19,7 +19,6 @@ Heuristic strategies:
- Memory-bound: Optimize memory access for bandwidth-limited cases
- Latency-focused: Minimize kernel launch overhead for small problems
Complexity: ★★★★☆
Usage:
python3 08_heuristics.py
@@ -43,6 +42,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
@@ -561,8 +561,8 @@ Examples:
)
parser.add_argument(
"--arch",
default="gfx942",
help="Target architecture (default: gfx942)",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()

View File

@@ -8,7 +8,6 @@ Example 09: Multiple Registries
Demonstrates multiple registries for different optimization targets.
Complexity: ★★★★★
Usage:
python3 09_multi_registry.py
@@ -30,6 +29,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
@@ -50,7 +50,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
"--arch", default="gfx942", help="Target architecture (default: gfx942)"
"--arch",
default=detect_gpu_arch(),
help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()

View File

@@ -33,6 +33,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
@@ -69,7 +70,11 @@ def parse_args():
# Kernel configuration
parser.add_argument("--dtype", default="fp16", help="Data type")
parser.add_argument("--pipeline", default="compv4", help="Pipeline type")
parser.add_argument("--arch", default="gfx942", help="GPU architecture")
parser.add_argument(
"--arch",
default=detect_gpu_arch(),
help="GPU architecture (auto-detected from rocminfo)",
)
return parser.parse_args()

View File

@@ -15,7 +15,6 @@ Key Features:
- Use arch_filter validation on loaded configs
- Export to C++ DECL_KERNEL_SET format
Complexity: ★★★☆☆
Usage:
python3 11_json_import.py
@@ -45,6 +44,7 @@ from ctypes_utils import ( # noqa: E402
cleanup_gemm,
reset_for_example,
validate_kernel_config,
detect_gpu_arch,
)
# Sample JSON configuration (embedded for demonstration)
@@ -141,8 +141,8 @@ Examples:
)
parser.add_argument(
"--arch",
default="gfx942",
help="Target GPU architecture (default: gfx942)",
default=detect_gpu_arch(),
help="Target GPU architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()
@@ -236,13 +236,13 @@ Examples:
else:
invalid_count += 1
if invalid_count <= 3: # Show first 3 invalid
print(f"\n Invalid: {config.kernel_name()}")
print(f"\n FAIL Invalid: {config.kernel_name()}")
for error in result.errors:
print(f" Error: {error}")
print("\n Validation Summary:")
print(f" Valid: {valid_count}")
print(f" Invalid: {invalid_count}")
print(f" OK Valid: {valid_count}")
print(f" FAIL Invalid: {invalid_count}")
print(f" Total: {len(configs)}")
# =========================================================================
@@ -275,12 +275,12 @@ Examples:
disp_config, registry_name="json_import", verbose=False
)
if setup.success:
print(" Dispatcher setup successful")
print(" OK Dispatcher setup successful")
print(
f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}"
)
else:
print(f" Dispatcher setup: {setup.error}")
print(f" WARNING Dispatcher setup: {setup.error}")
print(" (This is expected if kernels aren't generated)")
# =========================================================================

View File

@@ -295,5 +295,5 @@ Compilation time scales roughly linearly with kernel count.
## Related Documentation
- [C++ GEMM Examples](../cpp/README.md)
- [Python Conv Examples](../../conv/python/README.md)
- [Python Utilities](../../../python/README.md)
- [Main Dispatcher README](../../../README.md)

View File

@@ -0,0 +1,203 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 01: Basic Grouped Convolution
//
// Demonstrates three declaration patterns (mirrors GEMM 01):
// 1. AUTOFILL - tile + pipeline only, wave/warp auto-filled
// 2. AUTOCORRECT - invalid wave(1,1,1) corrected to valid config
// 3. FULL - all parameters explicit (matches validated gfx942 config)
//
// Then runs the forward convolution on GPU and verifies output.
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_01_basic
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
// Three declaration patterns -- codegen auto-fills/auto-corrects as needed
DECL_GROUPED_CONV_KERNEL_SET(
basic_conv_kernels,
// Pattern 1: AUTOFILL - only tile + pipeline, rest auto-filled
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").scheduler("intrawave"),
"gfx950")
// Pattern 2: AUTOCORRECT - wave(1,1,1) invalid, corrected to (1,4,1)
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, 64, 64)
.wave(1, 1, 1)
.warp(16, 16, 32)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle")
.vector_sizes(4, 8, 8),
"gfx950")
// Pattern 3: FULL - all parameters explicit (validated config)
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, 128, 128)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle")
.vector_sizes(4, 8, 8)
.block_per_cu(1),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 01: Basic Grouped Convolution",
"Declaration patterns + GPU execution");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--size", "14", "Spatial size (H=W)");
args.add_option("-n", "1", "Batch size");
args.add_option("-g", "1", "Groups");
args.add_option("-c", "64", "Input channels C");
args.add_option("-k", "128", "Output channels K");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 01: Basic Grouped Convolution");
std::string gfx_arch = args.get("--arch", "gfx950");
int N = args.get_int("-n", 1);
int G = args.get_int("-g", 1);
int C = args.get_int("-c", 64);
int K = args.get_int("-k", 128);
int HW = args.get_int("--size", 14);
int Y = 3, X = 3;
// Step 1: Show declared kernel sets
std::cout << "\nStep 1: Declared Kernel Sets\n";
GroupedConvKernelSetRegistry::instance().print();
// Step 2: Register kernels
std::cout << "\nStep 2: Register Kernels\n";
GroupedConvRegistry registry;
registry.set_name("basic_conv");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
// Step 3: Create dispatcher
std::cout << "\nStep 3: Create Dispatcher\n";
GroupedConvDispatcher dispatcher(&registry);
// Step 4: Build problem using CK Tile ConvParam
std::cout << "\nStep 4: Problem\n";
auto problem = create_grouped_conv2d_problem(N, C, K, HW, HW, Y, X, 1, 1);
problem.op = GroupedConvOp::Forward;
print_grouped_conv_problem(problem);
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(G),
static_cast<ck_tile::index_t>(N),
static_cast<ck_tile::index_t>(K),
static_cast<ck_tile::index_t>(C),
{static_cast<ck_tile::index_t>(Y), static_cast<ck_tile::index_t>(X)},
{static_cast<ck_tile::index_t>(HW), static_cast<ck_tile::index_t>(HW)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input_host(in_desc);
ck_tile::HostTensor<WeiDataType> weight_host(wei_desc);
ck_tile::HostTensor<OutDataType> output_host(out_desc);
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input_host);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight_host);
ck_tile::DeviceMem input_dev(input_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev(weight_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev(output_host.get_element_space_size_in_bytes());
input_dev.ToDevice(input_host.data());
weight_dev.ToDevice(weight_host.data());
// Step 5: Select and run
std::cout << "\nStep 5: Select and Run\n";
auto* selected = dispatcher.select_kernel(problem);
if(!selected)
{
std::cerr << " ERROR: No kernel found for problem!\n";
return 1;
}
std::cout << " Selected: " << selected->name() << "\n";
float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
weight_dev.GetDeviceBuffer(),
output_dev.GetDeviceBuffer(),
problem,
nullptr);
double tflops = calculate_conv_tflops(problem, time_ms);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Step 6: Verify
std::cout << "\nStep 6: Verify\n";
output_dev.FromDevice(output_host.data());
size_t total = output_host.get_element_space_size();
size_t nonzero = 0;
double checksum = 0.0;
for(size_t i = 0; i < total; ++i)
{
float v = static_cast<float>(output_host.data()[i]);
if(v != 0.0f)
++nonzero;
checksum += v;
}
bool passed = nonzero > 0;
std::cout << " Output elements: " << total << "\n";
std::cout << " Non-zero: " << nonzero << "/" << total
<< (nonzero > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n";
std::cout << " Checksum: " << std::fixed << std::setprecision(2) << checksum << "\n";
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
utils::print_separator();
std::cout << "DECLARATION PATTERNS:\n";
std::cout << " 1. AUTOFILL: tile + pipeline only, wave/warp auto-filled\n";
std::cout << " 2. AUTOCORRECT: invalid wave(1,1,1) corrected\n";
std::cout << " 3. FULL: all parameters explicit\n";
utils::print_separator();
return passed ? 0 : 1;
}

View File

@@ -0,0 +1,216 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 02: All Convolution Directions
//
// Forward, backward-data, and backward-weight for 2D convolution,
// each executed on GPU with non-zero verification.
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_02_all_dirs
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
DECL_GROUPED_CONV_KERNEL_SET(
conv_fwd_2d,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
"gfx950"));
DECL_GROUPED_CONV_KERNEL_SET(
conv_bwdd_2d,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
GroupedConvAlgo().tile(1, 128, 128).pipeline("compv3").vector_sizes(4, 8, 8),
"gfx950"));
DECL_GROUPED_CONV_KERNEL_SET(
conv_bwdw_2d,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2),
GroupedConvAlgo()
.tile(1, 128, 128)
.pipeline("compv3")
.memory_op("atomic_add")
.vector_sizes(4, 8, 8),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 02: All Convolution Directions",
"Forward/BwdData/BwdWeight with GPU execution and verification");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 02: All Convolution Directions");
std::string gfx_arch = args.get("--arch", "gfx950");
GroupedConvRegistry registry;
registry.set_name("all_directions");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
GroupedConvDispatcher dispatcher(&registry);
const int N = 1, G = 1, C = 64, K = 128, Hi = 14, Wi = 14, Y = 3, X = 3;
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(G),
static_cast<ck_tile::index_t>(N),
static_cast<ck_tile::index_t>(K),
static_cast<ck_tile::index_t>(C),
{static_cast<ck_tile::index_t>(Y), static_cast<ck_tile::index_t>(X)},
{static_cast<ck_tile::index_t>(Hi), static_cast<ck_tile::index_t>(Wi)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_desc);
ck_tile::HostTensor<OutDataType> output(out_desc);
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev(output.get_element_space_size_in_bytes());
input_dev.ToDevice(input.data());
weight_dev.ToDevice(weight.data());
std::cout << "\n " << std::left << std::setw(12) << "Direction" << std::right << std::setw(10)
<< "Time(ms)" << std::setw(10) << "TFLOPS" << std::setw(14) << "NonZero"
<< std::setw(10) << "Status" << "\n";
std::cout << " " << std::string(56, '-') << "\n";
bool all_pass = true;
auto print_result =
[](const char* label, float time_ms, double tflops, size_t nz, size_t total, bool ok) {
std::cout << " " << std::left << std::setw(12) << label << std::right << std::fixed
<< std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2)
<< std::setw(10) << tflops << std::setw(14)
<< (std::to_string(nz) + "/" + std::to_string(total)) << std::setw(10)
<< (ok ? "OK" : "FAIL") << "\n";
};
// Forward: run(X, W, Y)
{
auto problem =
create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::Forward);
float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
weight_dev.GetDeviceBuffer(),
output_dev.GetDeviceBuffer(),
problem,
nullptr);
output_dev.FromDevice(output.data());
size_t nz = 0;
for(size_t i = 0; i < output.get_element_space_size(); ++i)
if(static_cast<float>(output.data()[i]) != 0.0f)
++nz;
bool ok = nz > 0;
print_result("forward",
time_ms,
calculate_conv_tflops(problem, time_ms),
nz,
output.get_element_space_size(),
ok);
if(!ok)
all_pass = false;
}
// Backward Data: run(dY, W, dX)
{
auto problem =
create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData);
ck_tile::HostTensor<InDataType> dx_host(in_desc);
ck_tile::DeviceMem dx_dev(dx_host.get_element_space_size_in_bytes());
float time_ms = dispatcher.run(output_dev.GetDeviceBuffer(), // dY (from forward pass)
weight_dev.GetDeviceBuffer(), // W
dx_dev.GetDeviceBuffer(), // dX (output)
problem,
nullptr);
dx_dev.FromDevice(dx_host.data());
size_t nz = 0;
for(size_t i = 0; i < dx_host.get_element_space_size(); ++i)
if(static_cast<float>(dx_host.data()[i]) != 0.0f)
++nz;
bool ok = nz > 0;
print_result("bwd_data",
time_ms,
calculate_conv_tflops(problem, time_ms),
nz,
dx_host.get_element_space_size(),
ok);
if(!ok)
all_pass = false;
}
// Backward Weight: run(X, dY, dW)
{
auto problem = create_grouped_conv2d_problem(
N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight);
ck_tile::HostTensor<WeiDataType> dw_host(wei_desc);
ck_tile::DeviceMem dw_dev(dw_host.get_element_space_size_in_bytes());
float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), // X
output_dev.GetDeviceBuffer(), // dY
dw_dev.GetDeviceBuffer(), // dW (output)
problem,
nullptr);
dw_dev.FromDevice(dw_host.data());
size_t nz = 0;
for(size_t i = 0; i < dw_host.get_element_space_size(); ++i)
if(static_cast<float>(dw_host.data()[i]) != 0.0f)
++nz;
bool ok = nz > 0;
print_result("bwd_weight",
time_ms,
calculate_conv_tflops(problem, time_ms),
nz,
dw_host.get_element_space_size(),
ok);
if(!ok)
all_pass = false;
}
utils::print_separator();
std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return all_pass ? 0 : 1;
}
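`calculate_conv_tflops` comes from `grouped_conv_utils` and its definition isn't shown here, but example 03 computes the same quantity inline (`flops = 2.0 * G * N * K * C * Y * X * Ho * Wo`, then `tflops = flops / (time_ms * 1e9)`). A small Python sketch of that arithmetic, using example 02's problem size (the helper names are illustrative):

```python
def conv2d_flops(G, N, K, C, Y, X, Ho, Wo):
    # 2 * (multiply-accumulate count) for a grouped 2D convolution,
    # with K and C counted per group, as in example 03's inline formula.
    return 2.0 * G * N * K * C * Y * X * Ho * Wo

def out_size(Hi, Y, pad=1, stride=1, dilation=1):
    # Standard convolution output-size formula.
    return (Hi + 2 * pad - dilation * (Y - 1) - 1) // stride + 1

# Example 02's problem: N=1, G=1, C=64, K=128, 14x14 input, 3x3 filter,
# stride 1, pad 1 -> same-size 14x14 output.
Ho = out_size(14, 3)
flops = conv2d_flops(G=1, N=1, K=128, C=64, Y=3, X=3, Ho=Ho, Wo=Ho)

# time_ms * 1e9 converts milliseconds to TFLOP/s in one step:
# flops / (ms * 1e-3 s/ms * 1e12 FLOP/TFLOP) == flops / (ms * 1e9).
tflops_at_1ms = flops / (1.0 * 1e9)
```

At this small size the problem is only about 29 MFLOP, which is why the examples also sweep larger batch and spatial sizes when benchmarking.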

View File

@@ -0,0 +1,263 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 03: Benchmark and CPU-Reference Validation
//
// Runs a 2D grouped conv forward kernel on the GPU via dispatcher.run()
// and compares against the CK Tile host reference implementation.
// Exposes warmup/repeat as CLI args (matches example 20 pattern).
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_03_bench_val
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using AccDataType = float;
DECL_GROUPED_CONV_KERNEL_SET(
bench_kernels,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
"gfx950")
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 03: Benchmark & Validation",
"GPU execution with CPU reference validation");
args.add_option("-n", "1", "Batch size N");
args.add_option("-g", "1", "Groups G");
args.add_option("-c", "64", "Input channels C");
args.add_option("-k", "128", "Output channels K");
args.add_option("--size", "14", "Spatial size (H=W)");
args.add_option("--warmup", "3", "Warmup iterations");
args.add_option("--repeat", "10", "Benchmark iterations");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_flag("--no-verify", "Skip CPU validation");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 03: Grouped Conv Benchmark & Validation");
int N = args.get_int("-n", 1);
int G = args.get_int("-g", 1);
int C = args.get_int("-c", 64);
int K = args.get_int("-k", 128);
int Hi = args.get_int("--size", 14);
int Wi = Hi;
int Y = 3, X = 3;
int warmup = args.get_int("--warmup", 3);
int repeat = args.get_int("--repeat", 10);
bool verify = !args.has("--no-verify");
std::string gfx_arch = args.get("--arch", "gfx950");
std::cout << "\nProblem: N=" << N << " G=" << G << " C=" << C << " K=" << K << " Hi=" << Hi
<< " Wi=" << Wi << " Y=" << Y << " X=" << X << "\n";
std::cout << "Benchmark: warmup=" << warmup << " repeat=" << repeat << "\n";
// Step 1: Setup tensors using CK Tile descriptors
std::cout << "\nStep 1: Setup tensors\n";
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(G),
static_cast<ck_tile::index_t>(N),
static_cast<ck_tile::index_t>(K),
static_cast<ck_tile::index_t>(C),
{static_cast<ck_tile::index_t>(Y), static_cast<ck_tile::index_t>(X)},
{static_cast<ck_tile::index_t>(Hi), static_cast<ck_tile::index_t>(Wi)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_desc);
ck_tile::HostTensor<OutDataType> output_gpu(out_desc);
ck_tile::HostTensor<OutDataType> output_cpu(out_desc);
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
output_cpu.SetZero();
std::cout << " Input: " << input.get_element_space_size() << " elements\n";
std::cout << " Weight: " << weight.get_element_space_size() << " elements\n";
std::cout << " Output: " << output_gpu.get_element_space_size() << " elements\n";
// Step 2: CPU reference
if(verify)
{
std::cout << "\nStep 2: CPU Reference\n";
std::vector<ck_tile::long_index_t> strides_v = {1, 1};
std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
ck_tile::reference_grouped_conv_fwd<2, InDataType, WeiDataType, OutDataType>(
input, weight, output_cpu, strides_v, dilations_v, left_pads_v, right_pads_v);
std::cout << " CPU ref[0..7]: ";
for(int i = 0; i < std::min(8, static_cast<int>(output_cpu.get_element_space_size())); ++i)
std::cout << std::fixed << std::setprecision(4)
<< static_cast<float>(output_cpu.data()[i]) << " ";
std::cout << "\n";
}
// Step 3: GPU execution via dispatcher
std::cout << "\nStep 3: GPU Execution (via dispatcher.run)\n";
GroupedConvRegistry registry;
registry.set_name("bench_val");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
GroupedConvDispatcher dispatcher(&registry);
auto problem = create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1);
problem.op = GroupedConvOp::Forward;
auto* selected = dispatcher.select_kernel(problem);
if(!selected)
{
std::cerr << " ERROR: No kernel found!\n";
return 1;
}
std::cout << " Selected: " << selected->name() << "\n";
ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev(output_gpu.get_element_space_size_in_bytes());
input_dev.ToDevice(input.data());
weight_dev.ToDevice(weight.data());
float elapsed_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
weight_dev.GetDeviceBuffer(),
output_dev.GetDeviceBuffer(),
problem,
nullptr);
output_dev.FromDevice(output_gpu.data());
size_t total = output_gpu.get_element_space_size();
std::cout << " GPU out[0..7]: ";
for(int i = 0; i < std::min(8, static_cast<int>(total)); ++i)
std::cout << std::fixed << std::setprecision(4) << static_cast<float>(output_gpu.data()[i])
<< " ";
std::cout << "\n";
size_t nonzero_gpu = 0;
double gpu_sum = 0.0;
for(size_t i = 0; i < total; ++i)
{
float v = static_cast<float>(output_gpu.data()[i]);
if(v != 0.0f)
++nonzero_gpu;
gpu_sum += v;
}
std::cout << " GPU checksum: " << std::fixed << std::setprecision(6) << gpu_sum << "\n";
std::cout << " GPU non-zero: " << nonzero_gpu << "/" << total
<< (nonzero_gpu > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n";
int Ho = static_cast<int>(problem.Ho());
int Wo = static_cast<int>(problem.Wo());
double flops = 2.0 * G * N * K * C * Y * X * Ho * Wo;
double tflops = flops / (elapsed_ms * 1e9);
std::cout << " Time: " << std::fixed << std::setprecision(4) << elapsed_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Step 4: Validation
bool passed = true;
if(verify)
{
std::cout << "\nStep 4: Validation (GPU vs CPU)\n";
constexpr float rtol = 1e-2f;
constexpr float atol = 1e-2f;
float max_diff = 0.0f;
float max_rel = 0.0f;
size_t max_diff_idx = 0;
size_t num_elements = output_gpu.get_element_space_size();
size_t mismatches = 0;
for(size_t i = 0; i < num_elements; ++i)
{
float gpu_val = static_cast<float>(output_gpu.data()[i]);
float cpu_val = static_cast<float>(output_cpu.data()[i]);
float diff = std::abs(gpu_val - cpu_val);
float tol = atol + rtol * std::abs(cpu_val);
float rel = diff / (std::abs(cpu_val) + 1e-6f);
if(diff > max_diff)
{
max_diff = diff;
max_diff_idx = i;
}
max_rel = std::max(max_rel, rel);
if(diff > tol)
++mismatches;
}
passed = (mismatches == 0);
std::cout << " Side-by-side at worst element [" << max_diff_idx << "]:\n";
std::cout << " GPU: " << std::fixed << std::setprecision(6)
<< static_cast<float>(output_gpu.data()[max_diff_idx])
<< " CPU: " << static_cast<float>(output_cpu.data()[max_diff_idx])
<< " diff: " << std::scientific << max_diff << "\n";
std::cout << " Elements: " << num_elements << "\n";
std::cout << " Mismatches: " << mismatches << "/" << num_elements << "\n";
std::cout << " Max abs diff: " << std::scientific << max_diff << "\n";
std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
std::cout << " Status: " << (passed ? "PASSED" : "FAILED") << "\n";
}
utils::print_separator();
std::cout << "BENCHMARK & VALIDATION:\n";
std::cout << " GPU kernel: " << (selected ? selected->name() : "none") << "\n";
std::cout << " Performance: " << std::fixed << std::setprecision(2) << tflops
<< " TFLOPS\n";
std::cout << " CPU reference: reference_grouped_conv_fwd<2>()\n";
std::cout << " Validation: " << (passed ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return passed ? 0 : 1;
}
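Step 4's pass criterion combines absolute and relative tolerance: an element mismatches when `|gpu - cpu| > atol + rtol * |cpu|`, the same shape as `numpy.isclose` with the CPU reference on the right-hand side. A minimal Python sketch of that check (function name is illustrative):

```python
def count_mismatches(gpu, cpu, rtol=1e-2, atol=1e-2):
    # Mirrors the validation loop in Step 4 of example 03:
    # an element passes when |gpu - cpu| <= atol + rtol * |cpu|.
    mismatches = 0
    max_diff = 0.0
    for g, c in zip(gpu, cpu):
        diff = abs(g - c)
        if diff > max_diff:
            max_diff = diff
        if diff > atol + rtol * abs(c):
            mismatches += 1
    return mismatches, max_diff

# 1.000 and 2.010 are within tolerance of 1.0 and 2.0;
# 3.5 vs 3.0 exceeds atol + rtol*|3.0| = 0.04, so it is a mismatch.
m, d = count_mismatches([1.000, 2.010, 3.5], [1.0, 2.0, 3.0])
```

Scaling the tolerance by the reference magnitude keeps the check meaningful for fp16 accumulation error, which grows with the reduction length C*Y*X.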

View File

@@ -0,0 +1,154 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 04: Heuristic Selection + JSON Export
//
// Demonstrates runtime kernel selection with heuristic ranking,
// GPU execution, and JSON registry export.
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_04_registry_json
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
// Two tile configs for heuristic selection
DECL_GROUPED_CONV_KERNEL_SET(
heuristic_kernels,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
"gfx950")
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8),
"gfx950"));
std::vector<std::string> conv_heuristic(const GroupedConvProblem& problem)
{
int64_t spatial = problem.Ho() * problem.Wo();
if(spatial > 400)
return {"128x128", "64x64"};
return {"64x64", "128x128"};
}
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 04: Heuristic + JSON",
"Runtime kernel selection and JSON export");
args.add_option("--arch", "gfx950", "GPU architecture");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 04: Heuristic Selection + JSON Export");
std::string gfx_arch = args.get("--arch", "gfx950");
// Step 1: Register
std::cout << "\nStep 1: Register Kernels" << std::endl;
GroupedConvRegistry registry;
registry.set_name("heuristic_conv");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)" << std::endl;
// Step 2: Heuristic dispatcher
std::cout << "\nStep 2: Heuristic Dispatcher" << std::endl;
GroupedConvDispatcher dispatcher(&registry);
dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic);
dispatcher.set_heuristic(conv_heuristic);
// Step 3: Select kernels (no GPU yet)
std::cout << "\nStep 3: Kernel Selection" << std::endl;
auto problem = create_grouped_conv2d_problem(1, 64, 128, 14, 14, 3, 3, 1, 1);
auto* selected = dispatcher.select_kernel(problem);
std::cout << " Selected: " << (selected ? selected->name() : "none") << std::endl;
// Step 4: GPU execution
std::cout << "\nStep 4: GPU Execution" << std::endl;
ck_tile::conv::ConvParam cp{
2,
static_cast<ck_tile::index_t>(1),
static_cast<ck_tile::index_t>(1),
static_cast<ck_tile::index_t>(128),
static_cast<ck_tile::index_t>(64),
{static_cast<ck_tile::index_t>(3), static_cast<ck_tile::index_t>(3)},
{static_cast<ck_tile::index_t>(14), static_cast<ck_tile::index_t>(14)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
std::cout << " Creating tensors..." << std::endl;
auto in_d = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(cp);
auto wei_d = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(cp);
auto out_d = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(cp);
ck_tile::HostTensor<InDataType> input(in_d);
ck_tile::HostTensor<WeiDataType> weight(wei_d);
ck_tile::HostTensor<OutDataType> output(out_d);
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
std::cout << " Allocating device memory..." << std::endl;
ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes());
in_dev.ToDevice(input.data());
wei_dev.ToDevice(weight.data());
std::cout << " Launching kernel..." << std::endl;
float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(),
wei_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(),
problem,
nullptr);
std::cout << " Reading back..." << std::endl;
out_dev.FromDevice(output.data());
size_t nz = 0;
for(size_t i = 0; i < output.get_element_space_size(); ++i)
if(static_cast<float>(output.data()[i]) != 0.0f)
++nz;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms"
<< std::endl;
std::cout << " TFLOPS: " << std::setprecision(2) << calculate_conv_tflops(problem, time_ms)
<< std::endl;
std::cout << " NonZero: " << nz << "/" << output.get_element_space_size() << std::endl;
// Step 5: JSON export
std::cout << "\nStep 5: JSON Export" << std::endl;
std::string json = registry.export_json(false);
std::cout << " JSON size: " << json.size() << " bytes" << std::endl;
bool passed = nz > 0;
utils::print_separator();
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return passed ? 0 : 1;
}
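The `conv_heuristic` above returns an ordered list of tile-name preferences, and the dispatcher tries them in order, running the first preference that has a registered kernel. A minimal Python sketch of that selection logic (the 400-element spatial threshold and tile names mirror the example; the registry dict is a hypothetical stand-in for `GroupedConvRegistry`):

```python
# Sketch of heuristic-ordered kernel selection, mirroring conv_heuristic()
# in the C++ example above. The registry mapping is hypothetical.

def conv_heuristic(ho, wo):
    """Prefer the larger tile when the output spatial size is large."""
    if ho * wo > 400:
        return ["128x128", "64x64"]
    return ["64x64", "128x128"]

def select_kernel(registry, ho, wo):
    """Return the first preferred tile that is actually registered."""
    for tile in conv_heuristic(ho, wo):
        if tile in registry:
            return tile
    return None

registry = {"64x64": "kernel_a", "128x128": "kernel_b"}
print(select_kernel(registry, 28, 28))  # 28*28 = 784 > 400 -> "128x128"
print(select_kernel(registry, 14, 14))  # 196 <= 400 -> "64x64"
```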


@@ -0,0 +1,183 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 05: Backward Data with CPU Reference Validation
//
// Computes dX = ConvBwdData(dY, W) on GPU via dispatcher.run()
// and validates against ck_tile::reference_grouped_conv_bwd_data.
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_05_bwd_data
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
DECL_GROUPED_CONV_KERNEL_SET(
bwd_data_kernels,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
GroupedConvAlgo()
.tile(1, 128, 128)
.pipeline("compv3")
.scheduler("intrawave")
.vector_sizes(4, 8, 8),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 05: Backward Data Validation",
"dX = ConvBwdData(dY, W) with CPU reference");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("-n", "1", "Batch size");
args.add_option("-c", "64", "Input channels");
args.add_option("-k", "128", "Output channels");
args.add_option("--size", "14", "Spatial size (H=W)");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 05: Backward Data with CPU Validation");
std::string gfx_arch = args.get("--arch", "gfx950");
int N = args.get_int("-n", 1), G = 1;
int C = args.get_int("-c", 64), K = args.get_int("-k", 128);
int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3;
// Setup
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(G),
static_cast<ck_tile::index_t>(N),
static_cast<ck_tile::index_t>(K),
static_cast<ck_tile::index_t>(C),
{static_cast<ck_tile::index_t>(Y), static_cast<ck_tile::index_t>(X)},
{static_cast<ck_tile::index_t>(Hi), static_cast<ck_tile::index_t>(Wi)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
// dY (gradient from next layer) and W (weight) are inputs; dX is output
ck_tile::HostTensor<OutDataType> dy(out_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_desc);
ck_tile::HostTensor<InDataType> dx_gpu(in_desc);
ck_tile::HostTensor<InDataType> dx_cpu(in_desc);
ck_tile::FillUniformDistribution<OutDataType>{-0.5f, 0.5f}(dy);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
dx_cpu.SetZero();
// CPU reference
std::cout << "\nStep 1: CPU Reference (bwd_data)\n";
std::vector<ck_tile::long_index_t> strides_v = {1, 1};
std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
ck_tile::reference_grouped_conv_bwd_data<2, InDataType, WeiDataType, OutDataType>(
dx_cpu, weight, dy, strides_v, dilations_v, left_pads_v, right_pads_v);
std::cout << " CPU complete\n";
// GPU execution via dispatcher
std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n";
GroupedConvRegistry registry;
registry.set_name("bwd_data");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
GroupedConvDispatcher dispatcher(&registry);
auto problem =
create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData);
auto* selected = dispatcher.select_kernel(problem);
if(!selected)
{
std::cerr << " ERROR: No bwd_data kernel found!\n";
return 1;
}
std::cout << " Selected: " << selected->name() << "\n";
ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes());
ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem dx_dev(dx_gpu.get_element_space_size_in_bytes());
dy_dev.ToDevice(dy.data());
wei_dev.ToDevice(weight.data());
// dispatcher.run(dY, W, dX, problem) for bwd_data
float time_ms = dispatcher.run(dy_dev.GetDeviceBuffer(),
wei_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer(),
problem,
nullptr);
dx_dev.FromDevice(dx_gpu.data());
double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Validation
std::cout << "\nStep 3: Validation (GPU vs CPU)\n";
size_t num_elements = dx_gpu.get_element_space_size();
float max_abs = 0, max_rel = 0;
size_t mismatches = 0;
constexpr float rtol = 5e-2f, atol = 5e-2f;
for(size_t i = 0; i < num_elements; ++i)
{
float gv = static_cast<float>(dx_gpu.data()[i]);
float cv = static_cast<float>(dx_cpu.data()[i]);
float d = std::abs(gv - cv);
float r = d / (std::abs(cv) + 1e-6f);
max_abs = std::max(max_abs, d);
max_rel = std::max(max_rel, r);
if(d > atol + rtol * std::abs(cv))
++mismatches;
}
bool passed = (mismatches == 0);
std::cout << " Elements: " << num_elements << "\n";
std::cout << " Mismatches: " << mismatches << "\n";
std::cout << " Max abs diff: " << std::scientific << max_abs << "\n";
std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
utils::print_separator();
std::cout << " dX = ConvBwdData(dY, W)\n";
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return passed ? 0 : 1;
}
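The validation loop counts an element as a mismatch when `|gpu - cpu| > atol + rtol * |cpu|`, the same mixed absolute/relative criterion `numpy.allclose` uses: `atol` dominates near zero, `rtol` dominates for large magnitudes. A standalone sketch of that check with the example's 5e-2 tolerances:

```python
def count_mismatches(gpu, cpu, atol=5e-2, rtol=5e-2):
    """Mismatch when |g - c| exceeds atol + rtol*|c|, per element."""
    return sum(1 for g, c in zip(gpu, cpu) if abs(g - c) > atol + rtol * abs(c))

# 1.0 vs 1.04: diff 0.04 <= 0.05 + 0.05*1.04 -> within tolerance
# 0.0 vs 0.2 : diff 0.20 >  0.05 + 0.05*0.20 -> mismatch
print(count_mismatches([1.0, 0.0], [1.04, 0.2]))  # 1
```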


@@ -0,0 +1,188 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 06: Backward Weight with CPU Reference Validation
//
// Computes dW = ConvBwdWeight(X, dY) on GPU via dispatcher.run()
// and validates against ck_tile::reference_grouped_conv_bwd_weight.
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_06_bwd_weight
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
DECL_GROUPED_CONV_KERNEL_SET(
bwd_weight_kernels,
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2),
GroupedConvAlgo()
.tile(1, 128, 128)
.pipeline("compv3")
.scheduler("intrawave")
.memory_op("atomic_add")
.vector_sizes(4, 8, 8),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 06: Backward Weight Validation",
"dW = ConvBwdWeight(X, dY) with CPU reference");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("-n", "1", "Batch size");
args.add_option("-c", "64", "Input channels");
args.add_option("-k", "128", "Output channels");
args.add_option("--size", "14", "Spatial size (H=W)");
args.add_option("--split-k", "1", "Split-K factor for bwd_weight (k_batch)");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 06: Backward Weight with CPU Validation");
std::string gfx_arch = args.get("--arch", "gfx950");
int N = args.get_int("-n", 1), G = 1;
int C = args.get_int("-c", 64), K = args.get_int("-k", 128);
int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3;
// Setup
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(G),
static_cast<ck_tile::index_t>(N),
static_cast<ck_tile::index_t>(K),
static_cast<ck_tile::index_t>(C),
{static_cast<ck_tile::index_t>(Y), static_cast<ck_tile::index_t>(X)},
{static_cast<ck_tile::index_t>(Hi), static_cast<ck_tile::index_t>(Wi)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
// X (input) and dY (gradient) are inputs; dW is output
ck_tile::HostTensor<InDataType> input(in_desc);
ck_tile::HostTensor<OutDataType> dy(out_desc);
ck_tile::HostTensor<WeiDataType> dw_gpu(wei_desc);
ck_tile::HostTensor<WeiDataType> dw_cpu(wei_desc);
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
ck_tile::FillUniformDistribution<OutDataType>{-0.5f, 0.5f}(dy);
dw_cpu.SetZero();
// CPU reference
std::cout << "\nStep 1: CPU Reference (bwd_weight)\n";
std::vector<ck_tile::long_index_t> strides_v = {1, 1};
std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
ck_tile::reference_grouped_conv_bwd_weight<2, InDataType, WeiDataType, OutDataType>(
input, dw_cpu, dy, strides_v, dilations_v, left_pads_v, right_pads_v);
std::cout << " CPU complete\n";
// GPU execution
std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n";
GroupedConvRegistry registry;
registry.set_name("bwd_weight");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
GroupedConvDispatcher dispatcher(&registry);
auto problem =
create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight);
problem.split_k = args.get_int("--split-k", 1);
auto* selected = dispatcher.select_kernel(problem);
if(!selected)
{
std::cerr << " ERROR: No bwd_weight kernel found!\n";
return 1;
}
std::cout << " Selected: " << selected->name() << "\n";
ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes());
ck_tile::DeviceMem dw_dev(dw_gpu.get_element_space_size_in_bytes());
in_dev.ToDevice(input.data());
dy_dev.ToDevice(dy.data());
if(problem.split_k > 1)
dw_dev.SetZero();
// dispatcher.run(X, dY, dW, problem) for bwd_weight
float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(),
dy_dev.GetDeviceBuffer(),
dw_dev.GetDeviceBuffer(),
problem,
nullptr);
dw_dev.FromDevice(dw_gpu.data());
double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Validation
std::cout << "\nStep 3: Validation (GPU vs CPU)\n";
size_t num_elements = dw_gpu.get_element_space_size();
float max_abs = 0, max_rel = 0;
size_t mismatches = 0;
constexpr float rtol = 5e-2f, atol = 5e-2f;
for(size_t i = 0; i < num_elements; ++i)
{
float gv = static_cast<float>(dw_gpu.data()[i]);
float cv = static_cast<float>(dw_cpu.data()[i]);
float d = std::abs(gv - cv);
float r = d / (std::abs(cv) + 1e-6f);
max_abs = std::max(max_abs, d);
max_rel = std::max(max_rel, r);
if(d > atol + rtol * std::abs(cv))
++mismatches;
}
bool passed = (mismatches == 0);
std::cout << " Elements: " << num_elements << "\n";
std::cout << " Mismatches: " << mismatches << "\n";
std::cout << " Max abs diff: " << std::scientific << max_abs << "\n";
std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
utils::print_separator();
std::cout << " dW = ConvBwdWeight(X, dY)\n";
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return passed ? 0 : 1;
}
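With `split_k > 1` the bwd_weight kernel accumulates partial dW contributions from different K-chunks via `atomic_add`, which is why the example zeroes `dw_dev` first: stale buffer contents would be summed into the result. A toy sketch of why split-K needs a zero-initialized accumulator (pure Python; the chunked axis is a stand-in for the real GEMM reduction dimension):

```python
def split_k_reduce(contributions, k_batch):
    """Sum a reduction axis in k_batch chunks into a zero-initialized
    accumulator, analogous to atomic_add on a zeroed dW buffer."""
    acc = 0.0  # must start at zero, like dw_dev.SetZero()
    n = len(contributions)
    chunk = (n + k_batch - 1) // k_batch
    for b in range(k_batch):
        acc += sum(contributions[b * chunk:(b + 1) * chunk])
    return acc

data = [0.5, -0.25, 1.0, 0.75]
print(split_k_reduce(data, 1))  # 2.0
print(split_k_reduce(data, 2))  # 2.0, same result regardless of k_batch
```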


@@ -0,0 +1,226 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Example 07: Multi-Tile Benchmark
//
// Benchmarks multiple tile configurations across ResNet-like problem sizes.
// Exposes warmup, repeat, and init method as CLI args (matching CK Tile
// example 20 patterns).
//
// Build: cd dispatcher/build && cmake .. && make grouped_conv_07_benchmark
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_utils;
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
// Multiple tile configurations for benchmarking
DECL_GROUPED_CONV_KERNEL_SET(
benchmark_tiles,
// Small tile - compv3
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, 64, 64)
.wave(1, 4, 1)
.warp(16, 16, 32)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle")
.vector_sizes(4, 8, 8)
.block_per_cu(1),
"gfx950")
// Medium tile - compv3
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, 128, 128)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.scheduler("intrawave")
.epilogue("cshuffle")
.vector_sizes(4, 8, 8)
.block_per_cu(1),
"gfx950")
// Large tile - compv4 with double smem buffer
.add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, 256, 256)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave")
.epilogue("cshuffle")
.vector_sizes(4, 8, 8)
.block_per_cu(1),
"gfx950"));
int main(int argc, char* argv[])
{
utils::ExampleArgs args("Example 07: Multi-Tile Benchmark",
"Multiple tiles across ResNet-like problem sizes");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--warmup", "5", "Warmup iterations (passed to stream_config)");
args.add_option("--repeat", "20", "Benchmark iterations (passed to stream_config)");
args.add_option("--init", "0", "Init method: 0=random, 1=linear, 2=constant(1)");
if(!args.parse(argc, argv))
return 0;
utils::print_header("Example 07: Multi-Tile Benchmark");
std::string gfx_arch = args.get("--arch", "gfx950");
int warmup = args.get_int("--warmup", 5);
int repeat = args.get_int("--repeat", 20);
int init_method = args.get_int("--init", 0);
std::cout << "\n Config: warmup=" << warmup << " repeat=" << repeat << " init=" << init_method
<< "\n";
GroupedConvRegistry registry;
registry.set_name("benchmark");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
GroupedConvDispatcher dispatcher(&registry);
// ResNet-like problem sizes
struct BenchProblem
{
const char* label;
int N, C, K, Hi, Wi, Y, X;
};
BenchProblem problems[] = {
{"ResNet-stage2", 1, 64, 64, 56, 56, 3, 3},
{"ResNet-stage3", 1, 128, 128, 28, 28, 3, 3},
{"ResNet-stage4", 1, 256, 256, 14, 14, 3, 3},
{"ResNet-stage5", 1, 512, 512, 7, 7, 3, 3},
{"Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1},
{"Batch-8", 8, 64, 128, 56, 56, 3, 3},
};
std::cout << "\n " << std::left << std::setw(16) << "Problem" << std::right << std::setw(5)
<< "N" << std::setw(5) << "C" << std::setw(5) << "K" << std::setw(5) << "H"
<< std::setw(5) << "W" << std::setw(4) << "F" << std::setw(10) << "Time(ms)"
<< std::setw(10) << "TFLOPS" << std::setw(10) << "Status" << "\n";
std::cout << " " << std::string(74, '-') << "\n";
bool all_pass = true;
for(const auto& bp : problems)
{
auto problem =
create_grouped_conv2d_problem(bp.N, bp.C, bp.K, bp.Hi, bp.Wi, bp.Y, bp.X, 1, 1);
problem.op = GroupedConvOp::Forward;
ck_tile::conv::ConvParam conv_param{
2,
static_cast<ck_tile::index_t>(1),
static_cast<ck_tile::index_t>(bp.N),
static_cast<ck_tile::index_t>(bp.K),
static_cast<ck_tile::index_t>(bp.C),
{static_cast<ck_tile::index_t>(bp.Y), static_cast<ck_tile::index_t>(bp.X)},
{static_cast<ck_tile::index_t>(bp.Hi), static_cast<ck_tile::index_t>(bp.Wi)},
{1, 1},
{1, 1},
{1, 1},
{1, 1}};
using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
auto wei_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
auto out_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
ck_tile::HostTensor<InDataType> input(in_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_desc);
ck_tile::HostTensor<OutDataType> output(out_desc);
switch(init_method)
{
case 1:
ck_tile::FillMonotonicSeq<InDataType>{0.0f, 0.001f}(input);
ck_tile::FillMonotonicSeq<WeiDataType>{0.0f, 0.001f}(weight);
break;
case 2:
ck_tile::FillConstant<InDataType>{1.0f}(input);
ck_tile::FillConstant<WeiDataType>{1.0f}(weight);
break;
default:
ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
break;
}
ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes());
in_dev.ToDevice(input.data());
wei_dev.ToDevice(weight.data());
float time_ms = 0;
bool ok = false;
try
{
time_ms = dispatcher.run(in_dev.GetDeviceBuffer(),
wei_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(),
problem,
nullptr);
out_dev.FromDevice(output.data());
size_t nz = 0;
for(size_t j = 0; j < output.get_element_space_size(); ++j)
if(static_cast<float>(output.data()[j]) != 0.0f)
++nz;
ok = nz > 0;
}
catch(const std::exception&)
{
ok = false;
}
double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0;
std::string filter_str = std::to_string(bp.Y) + "x" + std::to_string(bp.X);
std::cout << " " << std::left << std::setw(16) << bp.label << std::right << std::setw(5)
<< bp.N << std::setw(5) << bp.C << std::setw(5) << bp.K << std::setw(5) << bp.Hi
<< std::setw(5) << bp.Wi << std::setw(4) << filter_str << std::fixed
<< std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2)
<< std::setw(10) << tflops << std::setw(10) << (ok ? "OK" : "FAIL") << "\n";
if(!ok)
all_pass = false;
}
utils::print_separator();
std::cout << " Warmup: " << warmup << ", Repeat: " << repeat << ", Init: " << init_method
<< "\n";
std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n";
utils::print_separator();
return all_pass ? 0 : 1;
}
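The benchmark reports TFLOPS via `calculate_conv_tflops`, which presumably follows the standard grouped-convolution FLOP count: 2 FLOPs (multiply + add) per MAC, with per-group channel counts. A hedged sketch of that accounting (the helper's exact internals may differ, e.g. in how it derives Ho/Wo from padding and stride):

```python
def conv2d_tflops(n, g, k_per_g, c_per_g, y, x, ho, wo, time_ms):
    """Grouped 2D conv: 2 * N * G * K/G * C/G * Y * X * Ho * Wo FLOPs,
    divided by elapsed time, reported in TFLOPS."""
    flops = 2.0 * n * g * k_per_g * c_per_g * y * x * ho * wo
    return flops / (time_ms * 1e-3) / 1e12

# ResNet-stage2-like shape: N=1, G=1, K=C=64, 3x3 filter, 56x56 output, 1 ms
print(round(conv2d_tflops(1, 1, 64, 64, 3, 3, 56, 56, 1.0), 3))  # 0.231
```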


@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 01: Basic Grouped Convolution
Demonstrates:
1. Three kernel configuration patterns (minimal, explicit, full ConvConfigBase)
2. Adding kernels to a registry
3. Validation and auto-correction
4. JIT compilation via registry.build()
5. GPU execution with CPU reference verification
Usage:
python3 01_basic_grouped_conv.py
python3 01_basic_grouped_conv.py --variant bwd_data
python3 01_basic_grouped_conv.py --arch gfx942
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
validate_grouped_conv_config,
auto_correct_grouped_conv_config,
detect_gpu_arch,
)
def cpu_conv2d_fwd(inp, wei, prob):
"""Naive CPU reference: 2D forward, NHWGC layout."""
N, Hi, Wi, G, Cpg = inp.shape
_, Kpg, Y, X, _ = wei.shape
Ho, Wo = prob.Ho, prob.Wo
out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32)
for n in range(N):
for g in range(G):
for ho in range(Ho):
for wo in range(Wo):
for k in range(Kpg):
s = 0.0
for y in range(Y):
for x in range(X):
hi = (
ho * prob.stride_h
- prob.pad_h
+ y * prob.dilation_h
)
wi = (
wo * prob.stride_w
- prob.pad_w
+ x * prob.dilation_w
)
if 0 <= hi < Hi and 0 <= wi < Wi:
for c in range(Cpg):
s += float(inp[n, hi, wi, g, c]) * float(
wei[g, k, y, x, c]
)
out[n, ho, wo, g, k] = s
return out
def main():
parser = argparse.ArgumentParser(description="Basic Grouped Conv Example")
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument(
"--variant", default="forward", choices=["forward", "bwd_data", "bwd_weight"]
)
parser.add_argument("--ndim", type=int, default=2, choices=[2, 3])
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument(
"--workers", type=int, default=0, help="Max JIT workers (0=auto)"
)
args = parser.parse_args()
print("=" * 70)
print("Example 01: Basic Grouped Convolution")
print("=" * 70)
# =========================================================================
# Step 1: Three kernel configuration patterns
# =========================================================================
print("\n--- Step 1: Kernel Configuration Patterns ---")
# Pattern 1: MINIMAL -- only variant/dtype/arch, everything else auto-filled
config_minimal = GroupedConvKernelConfig(
variant=args.variant,
ndim_spatial=args.ndim,
arch=args.arch,
dtype=args.dtype,
)
print("\n Pattern 1: MINIMAL (defaults auto-filled)")
config_minimal.print_config(indent=" ")
# Pattern 2: EXPLICIT tile/wave/warp -- user controls tiling strategy
config_explicit = GroupedConvKernelConfig(
variant=args.variant,
ndim_spatial=args.ndim,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
)
print("\n Pattern 2: EXPLICIT tile/wave/warp")
config_explicit.print_config(indent=" ")
# Pattern 3: FULL ConvConfigBase -- every parameter specified
config_full = GroupedConvKernelConfig(
variant=args.variant,
ndim_spatial=args.ndim,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
num_wave_groups=1,
num_groups_to_merge=1,
)
print("\n Pattern 3: FULL (all ConvConfigBase fields)")
config_full.print_config(indent=" ")
# =========================================================================
# Step 2: Build a registry with multiple configs
# =========================================================================
print("\n--- Step 2: Build Registry ---")
registry = GroupedConvRegistry("basic_conv")
registry.add(config_minimal)
registry.add(config_explicit)
registry.add(config_full)
registry.print_registry()
# =========================================================================
# Step 3: Validate and auto-correct
# =========================================================================
print("\n--- Step 3: Validate & Auto-Correct ---")
for i, cfg in enumerate(registry.kernels):
result = validate_grouped_conv_config(cfg.to_dict())
if result.is_valid:
print(f" Config [{i}] {cfg.tile_str}: VALID")
else:
print(f" Config [{i}] {cfg.tile_str}: needs correction")
corrected, result = auto_correct_grouped_conv_config(cfg.to_dict())
print(f" After correction: valid={result.is_valid}")
# =========================================================================
# Step 4: JIT compile via registry.build()
# =========================================================================
print("\n--- Step 4: JIT Build (via registry.build()) ---")
# Use only the first config for the actual GPU run
jit_reg = GroupedConvRegistry("jit")
jit_reg.add(config_minimal)
workers = args.workers if args.workers > 0 else None
t0 = time.perf_counter()
runners = jit_reg.build(verbose=False, max_workers=workers)
jit_build_s = time.perf_counter() - t0
key = (args.variant, args.ndim)
if key not in runners:
print(" JIT build failed")
return 1
runner = runners[key]
print(f" JIT build: {jit_build_s:.3f} s")
print(f" Library: {runner.library_path}")
print(f" Kernels: {runner.lib.kernel_names()}")
# =========================================================================
# Step 5: Define problem + GPU execution
# =========================================================================
print("\n--- Step 5: GPU Execution ---")
prob = GroupedConvProblem(
N=1,
C=64,
K=128,
Hi=16,
Wi=16,
Y=3,
X=3,
stride_h=1,
stride_w=1,
pad_h=1,
pad_w=1,
direction=args.variant,
)
prob.print_problem()
inp = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np.float16)
wei = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np.float16)
res = runner.run(inp, wei, prob)
if not res.success:
print(f" GPU execution failed: {res.error}")
runner.cleanup()
return 1
print(f" Time: {res.time_ms:.4f} ms")
print(f" TFLOPS: {res.tflops:.2f}")
print(
f" Output: shape={res.output.shape}, range=[{res.output.min():.3f}, {res.output.max():.3f}]"
)
# =========================================================================
# Step 6: CPU reference (forward 2D only)
# =========================================================================
verified = False
if args.variant == "forward" and args.ndim == 2:
print("\n--- Step 6: CPU Reference Verification ---")
ref = cpu_conv2d_fwd(inp, wei, prob)
gpu_f32 = res.output.astype(np.float32)
diff = np.abs(gpu_f32 - ref)
max_abs = diff.max()
max_rel = (diff / (np.abs(ref) + 1e-6)).max()
match = np.allclose(gpu_f32, ref, atol=0.05, rtol=0.05)
print(f" max_abs_diff: {max_abs:.6f}")
print(f" max_rel_diff: {max_rel:.6f}")
print(f" Match: {match}")
verified = match
runner.cleanup()
# Summary
print("\n" + "=" * 70)
    status = (
        "PASS"
        if res.success
        and (verified or not (args.variant == "forward" and args.ndim == 2))
        else "FAIL"
    )
print(f" Status: {status}")
print(
f" {config_minimal.name} | {prob.gflops:.2f} GFLOPs | {res.tflops:.2f} TFLOPS"
)
print(f" JIT build time: {jit_build_s:.3f} s")
print(f" Registry: {len(registry)} configs (3 patterns demonstrated)")
print("=" * 70)
return 0 if status == "PASS" else 1
if __name__ == "__main__":
sys.exit(main())
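`registry.build(max_workers=...)` compiles independent kernels concurrently, which is where the near-linear speedups in the PR's JIT timing table come from (52.3 s down to 10.2 s for 6 kernels). A generic sketch of that pattern using the standard library; the compile step here is a stand-in, not the real hipcc-backed JIT:

```python
from concurrent.futures import ThreadPoolExecutor

def build_all(configs, compile_fn, max_workers=None):
    """Compile each kernel config concurrently. Independent compiles
    (separate compiler invocations in the real JIT) parallelize cleanly."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return dict(zip(configs, pool.map(compile_fn, configs)))

# Stand-in "compiler": just derives a library name from the config tag.
runners = build_all(["64x64", "128x128"], lambda cfg: f"lib_{cfg}.so",
                    max_workers=8)
print(runners["128x128"])  # lib_128x128.so
```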


@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 02: Forward Convolution (2D + 3D)
Declares forward kernels with explicit tile/wave/warp/pipeline parameters,
builds a registry, JIT compiles, runs on GPU, and validates against CPU reference.
Usage:
python3 02_forward.py
python3 02_forward.py --arch gfx942
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
detect_gpu_arch,
)
def cpu_conv2d_fwd(inp, wei, prob):
"""Naive CPU reference: 2D forward, NHWGC layout."""
N, Hi, Wi, G, C = inp.shape
_, Kpg, Y, X, _ = wei.shape
Ho, Wo = prob.Ho, prob.Wo
out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32)
for n in range(N):
for g in range(G):
for ho in range(Ho):
for wo in range(Wo):
for k in range(Kpg):
s = 0.0
for y in range(Y):
for x in range(X):
hi = ho * prob.stride_h - prob.pad_h + y
wi = wo * prob.stride_w - prob.pad_w + x
if 0 <= hi < Hi and 0 <= wi < Wi:
for c in range(C):
s += float(inp[n, hi, wi, g, c]) * float(
wei[g, k, y, x, c]
)
out[n, ho, wo, g, k] = s
return out
def main():
parser = argparse.ArgumentParser(description="Forward Convolution (2D + 3D)")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument(
"--workers", type=int, default=0, help="Max JIT workers (0=auto)"
)
args = parser.parse_args()
arch = args.arch
print("=" * 70)
print("Example 02: Forward Convolution (2D + 3D)")
print("=" * 70)
print(f"\n Arch: {arch}, Dtype: {args.dtype}")
# =========================================================================
# Step 1: Declare forward kernels with explicit parameters
# =========================================================================
print("\n--- Step 1: Declare Forward Kernels ---")
reg = GroupedConvRegistry("forward_conv")
# Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv4",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=3,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
reg.print_registry()
# =========================================================================
# Step 2: JIT build via registry
# =========================================================================
print("\n--- Step 2: JIT Build ---")
workers = args.workers if args.workers > 0 else None
t0 = time.perf_counter()
runners = reg.build(verbose=False, max_workers=workers)
jit_s = time.perf_counter() - t0
print(f" Built {len(runners)} runners in {jit_s:.1f}s")
for key in [("forward", 2), ("forward", 3)]:
tag = "OK" if key in runners else "FAILED"
print(f" {key[0]} {key[1]}D: {tag}")
if ("forward", 2) not in runners:
print(" ERROR: forward 2D JIT failed")
return 1
    # NumPy has no native bf16 dtype; fp16 is used as the host-side container for both.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
# =========================================================================
# Step 3: Forward 2D -- GPU + CPU reference
# =========================================================================
print("\n--- Step 3: Forward 2D ---")
prob_2d = GroupedConvProblem(
N=1, C=64, K=64, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="forward"
)
prob_2d.print_problem()
x = np.random.uniform(-0.5, 0.5, prob_2d.input_shape()).astype(np_dtype)
w = np.random.uniform(-0.5, 0.5, prob_2d.weight_shape()).astype(np_dtype)
res = runners[("forward", 2)].run(x, w, prob_2d)
print(f" Time: {res.time_ms:.4f} ms")
print(f" TFLOPS: {res.tflops:.2f}")
print(
f" Output: shape={res.output.shape}, nonzero={np.count_nonzero(res.output)}/{res.output.size}"
)
ref = cpu_conv2d_fwd(x, w, prob_2d)
diff = np.abs(res.output.astype(np.float32) - ref)
match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.05)
print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}")
# =========================================================================
# Step 4: Forward 3D -- GPU + non-zero check
# =========================================================================
ok_3d = True
if ("forward", 3) in runners:
print("\n--- Step 4: Forward 3D ---")
prob_3d = GroupedConvProblem(
N=1,
C=64,
K=64,
Di=8,
Hi=8,
Wi=8,
Z=3,
Y=3,
X=3,
pad_d=1,
pad_h=1,
pad_w=1,
direction="forward",
)
prob_3d.print_problem()
x3 = np.random.uniform(-0.5, 0.5, prob_3d.input_shape()).astype(np_dtype)
w3 = np.random.uniform(-0.5, 0.5, prob_3d.weight_shape()).astype(np_dtype)
res3 = runners[("forward", 3)].run(x3, w3, prob_3d)
nz = np.count_nonzero(res3.output)
ok_3d = res3.success and nz > 0
print(f" Time: {res3.time_ms:.4f} ms")
print(f" TFLOPS: {res3.tflops:.2f}")
print(f" NonZero: {nz}/{res3.output.size}")
for r in runners.values():
r.cleanup()
passed = res.success and match_2d and ok_3d
print("\n" + "=" * 70)
print(f" Forward 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)")
print(f" Forward 3D: {'PASS' if ok_3d else 'FAIL'} (non-zero check)")
print(f" JIT build: {jit_s:.1f}s")
print(f" Status: {'PASS' if passed else 'FAIL'}")
print("=" * 70)
return 0 if passed else 1
if __name__ == "__main__":
sys.exit(main())
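The naive seven-deep loop in `cpu_conv2d_fwd` above is O(N·G·Ho·Wo·K·Y·X·C) in Python and becomes slow for anything beyond tiny shapes. A minimal vectorized sketch of the same NHWGC forward reference (one `einsum` per filter tap) is shown below; `cpu_conv2d_fwd_vec` and the `SimpleNamespace` problem stub are hypothetical names introduced here for illustration, not part of the PR.

```python
import numpy as np
from types import SimpleNamespace

def cpu_conv2d_fwd_vec(inp, wei, prob):
    """Vectorized NHWGC forward reference: pad once, then one einsum per (y, x) tap."""
    N, Hi, Wi, G, C = inp.shape
    _, Kpg, Y, X, _ = wei.shape
    Ho, Wo = prob.Ho, prob.Wo
    pads = ((0, 0), (prob.pad_h, prob.pad_h), (prob.pad_w, prob.pad_w), (0, 0), (0, 0))
    padded = np.pad(inp.astype(np.float32), pads)
    out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32)
    for y in range(Y):
        for x in range(X):
            # Strided window for this filter tap: shape (N, Ho, Wo, G, C)
            win = padded[:, y:y + Ho * prob.stride_h:prob.stride_h,
                            x:x + Wo * prob.stride_w:prob.stride_w]
            # Contract over C per group: (N,Ho,Wo,G,C) x (G,K,C) -> (N,Ho,Wo,G,K)
            out += np.einsum("nhwgc,gkc->nhwgk", win,
                             wei[:, :, y, x, :].astype(np.float32))
    return out

# Tiny stub mirroring the GroupedConvProblem fields used above (hypothetical).
prob = SimpleNamespace(Ho=4, Wo=4, stride_h=1, stride_w=1, pad_h=1, pad_w=1)
```

This stays numerically identical to the loop version (same accumulation order per tap up to float32 rounding), so the `atol=0.05` check would behave the same.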


@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 03: Backward Data Convolution (2D + 3D)
dX = ConvBwdData(dY, W)
Declares backward-data kernels with explicit parameters,
builds a registry, JIT compiles, runs on GPU, and validates
against a CPU reference.
Usage:
python3 03_bwd_data.py
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
detect_gpu_arch,
)
def cpu_conv2d_bwd_data(dy, wei, prob):
"""CPU ref: compute dX from dY and W."""
N, Ho, Wo, G, Kpg = dy.shape
_, _, Y, X, C = wei.shape
Hi, Wi = prob.Hi, prob.Wi
dx = np.zeros((N, Hi, Wi, G, C), dtype=np.float32)
for n in range(N):
for g in range(G):
for hi in range(Hi):
for wi in range(Wi):
for c in range(C):
s = 0.0
for y in range(Y):
for x in range(X):
ho = hi + prob.pad_h - y
wo = wi + prob.pad_w - x
if ho % prob.stride_h == 0 and wo % prob.stride_w == 0:
ho //= prob.stride_h
wo //= prob.stride_w
if 0 <= ho < Ho and 0 <= wo < Wo:
for k in range(Kpg):
s += float(dy[n, ho, wo, g, k]) * float(
wei[g, k, y, x, c]
)
dx[n, hi, wi, g, c] = s
return dx
def main():
parser = argparse.ArgumentParser(description="Backward Data (2D + 3D)")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument("--workers", type=int, default=0)
args = parser.parse_args()
arch = args.arch
print("=" * 70)
print("Example 03: Backward Data Convolution (2D + 3D)")
print("=" * 70)
print(f"\n Arch: {arch}, Dtype: {args.dtype}")
print(" dX = ConvBwdData(dY, W)")
# =========================================================================
# Step 1: Declare bwd_data kernels
# =========================================================================
print("\n--- Step 1: Declare BwdData Kernels ---")
reg = GroupedConvRegistry("bwd_data_conv")
# BwdData 2D: compv3, 128x128 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_data",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# BwdData 3D: compv3, 64x64 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_data",
ndim_spatial=3,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
reg.print_registry()
# =========================================================================
# Step 2: JIT build
# =========================================================================
print("\n--- Step 2: JIT Build ---")
workers = args.workers if args.workers > 0 else None
t0 = time.perf_counter()
runners = reg.build(verbose=False, max_workers=workers)
jit_s = time.perf_counter() - t0
print(f" Built {len(runners)} runners in {jit_s:.1f}s")
if ("bwd_data", 2) not in runners:
print(" ERROR: bwd_data 2D JIT failed")
return 1
    # NumPy has no native bf16 dtype; fp16 is used as the host-side container for both.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
# =========================================================================
# Step 3: BwdData 2D -- GPU + CPU reference
# =========================================================================
print("\n--- Step 3: Backward Data 2D ---")
prob = GroupedConvProblem(
N=1, C=32, K=32, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="bwd_data"
)
prob.print_problem()
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype)
w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np_dtype)
res = runners[("bwd_data", 2)].run(dy, w, prob)
print(f" Time: {res.time_ms:.4f} ms")
print(f" TFLOPS: {res.tflops:.2f}")
print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")
ref = cpu_conv2d_bwd_data(dy, w, prob)
diff = np.abs(res.output.astype(np.float32) - ref)
match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.1)
print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}")
# =========================================================================
# Step 4: BwdData 3D -- GPU + non-zero check
# =========================================================================
ok_3d = True
if ("bwd_data", 3) in runners:
print("\n--- Step 4: Backward Data 3D ---")
prob3 = GroupedConvProblem(
N=1,
C=32,
K=32,
Di=6,
Hi=6,
Wi=6,
Z=3,
Y=3,
X=3,
pad_d=1,
pad_h=1,
pad_w=1,
direction="bwd_data",
)
dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype)
w3 = np.random.uniform(-0.5, 0.5, prob3.weight_shape()).astype(np_dtype)
res3 = runners[("bwd_data", 3)].run(dy3, w3, prob3)
nz = np.count_nonzero(res3.output)
ok_3d = res3.success and nz > 0
print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}")
for r in runners.values():
r.cleanup()
passed = res.success and match_2d and ok_3d
print("\n" + "=" * 70)
print(f" BwdData 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)")
print(f" BwdData 3D: {'PASS' if ok_3d else 'FAIL'}")
print(f" Status: {'PASS' if passed else 'FAIL'}")
print("=" * 70)
return 0 if passed else 1
if __name__ == "__main__":
sys.exit(main())
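The gather-style loop in `cpu_conv2d_bwd_data` above (divide output coordinates by the stride, skip non-integral positions) has an equivalent scatter formulation: walk the forward index map and accumulate into a padded dX. A vectorized sketch follows; `cpu_conv2d_bwd_data_vec` and the problem stub are illustrative names, not part of the PR.

```python
import numpy as np
from types import SimpleNamespace

def cpu_conv2d_bwd_data_vec(dy, wei, prob):
    """Scatter-form dX reference: for each tap, scatter dY * W into padded dX."""
    N, Ho, Wo, G, Kpg = dy.shape
    _, _, Y, X, C = wei.shape
    Hi, Wi = prob.Hi, prob.Wi
    # Accumulate in padded coordinates (hi_padded = ho*stride + y), crop at the end.
    dxp = np.zeros((N, Hi + 2 * prob.pad_h, Wi + 2 * prob.pad_w, G, C),
                   dtype=np.float32)
    for y in range(Y):
        for x in range(X):
            # (N,Ho,Wo,G,K) x (G,K,C) -> (N,Ho,Wo,G,C) contribution of this tap
            contrib = np.einsum("nhwgk,gkc->nhwgc", dy.astype(np.float32),
                                wei[:, :, y, x, :].astype(np.float32))
            dxp[:, y:y + Ho * prob.stride_h:prob.stride_h,
                   x:x + Wo * prob.stride_w:prob.stride_w] += contrib
    return dxp[:, prob.pad_h:prob.pad_h + Hi, prob.pad_w:prob.pad_w + Wi]

# Hypothetical problem stub with the fields the sketch reads.
prob = SimpleNamespace(Hi=4, Wi=4, stride_h=1, stride_w=1, pad_h=1, pad_w=1)
```

The scatter form avoids the per-element stride-divisibility test, which is why it vectorizes cleanly.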


@@ -0,0 +1,224 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 04: Backward Weight Convolution (2D + 3D)
dW = ConvBwdWeight(X, dY)
Declares backward-weight kernels with explicit parameters,
builds a registry, JIT compiles, runs on GPU, and validates
against a CPU reference.
Usage:
python3 04_bwd_weight.py
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
detect_gpu_arch,
)
def cpu_conv2d_bwd_weight(x, dy, prob):
"""CPU ref: compute dW from X and dY."""
N, Hi, Wi, G, C = x.shape
_, Ho, Wo, _, Kpg = dy.shape
Y, X_ = prob.Y, prob.X
dw = np.zeros((G, Kpg, Y, X_, C), dtype=np.float32)
for g in range(G):
for k in range(Kpg):
for y in range(Y):
for xf in range(X_):
for c in range(C):
s = 0.0
for n in range(N):
for ho in range(Ho):
for wo in range(Wo):
hi = ho * prob.stride_h - prob.pad_h + y
wi = wo * prob.stride_w - prob.pad_w + xf
if 0 <= hi < Hi and 0 <= wi < Wi:
s += float(x[n, hi, wi, g, c]) * float(
dy[n, ho, wo, g, k]
)
dw[g, k, y, xf, c] = s
return dw
def main():
parser = argparse.ArgumentParser(description="Backward Weight (2D + 3D)")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument("--workers", type=int, default=0)
parser.add_argument(
"--split-k", type=int, default=1, help="Split-K factor for bwd_weight (k_batch)"
)
args = parser.parse_args()
arch = args.arch
print("=" * 70)
print("Example 04: Backward Weight Convolution (2D + 3D)")
print("=" * 70)
print(f"\n Arch: {arch}, Dtype: {args.dtype}")
print(" dW = ConvBwdWeight(X, dY)")
# =========================================================================
# Step 1: Declare bwd_weight kernels
# =========================================================================
print("\n--- Step 1: Declare BwdWeight Kernels ---")
reg = GroupedConvRegistry("bwd_weight_conv")
# BwdWeight 2D: compv3, 128x128 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_weight",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# BwdWeight 3D: compv3, 64x64 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_weight",
ndim_spatial=3,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
reg.print_registry()
# =========================================================================
# Step 2: JIT build
# =========================================================================
print("\n--- Step 2: JIT Build ---")
workers = args.workers if args.workers > 0 else None
t0 = time.perf_counter()
runners = reg.build(verbose=False, max_workers=workers)
jit_s = time.perf_counter() - t0
print(f" Built {len(runners)} runners in {jit_s:.1f}s")
if ("bwd_weight", 2) not in runners:
print(" ERROR: bwd_weight 2D JIT failed")
return 1
    # NumPy has no native bf16 dtype; fp16 is used as the host-side container for both.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
# =========================================================================
# Step 3: BwdWeight 2D -- GPU + CPU reference
# =========================================================================
print("\n--- Step 3: Backward Weight 2D ---")
prob = GroupedConvProblem(
N=1,
C=32,
K=32,
Hi=8,
Wi=8,
Y=3,
X=3,
pad_h=1,
pad_w=1,
direction="bwd_weight",
split_k=args.split_k,
)
prob.print_problem()
x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np_dtype)
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype)
res = runners[("bwd_weight", 2)].run(x, dy, prob)
print(f" Time: {res.time_ms:.4f} ms")
print(f" TFLOPS: {res.tflops:.2f}")
print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")
ref = cpu_conv2d_bwd_weight(x, dy, prob)
diff = np.abs(res.output.astype(np.float32) - ref)
match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.5)
print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}")
# =========================================================================
# Step 4: BwdWeight 3D -- GPU + non-zero check
# =========================================================================
ok_3d = True
if ("bwd_weight", 3) in runners:
print("\n--- Step 4: Backward Weight 3D ---")
prob3 = GroupedConvProblem(
N=1,
C=32,
K=32,
Di=6,
Hi=6,
Wi=6,
Z=3,
Y=3,
X=3,
pad_d=1,
pad_h=1,
pad_w=1,
direction="bwd_weight",
)
x3 = np.random.uniform(-0.5, 0.5, prob3.input_shape()).astype(np_dtype)
dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype)
res3 = runners[("bwd_weight", 3)].run(x3, dy3, prob3)
nz = np.count_nonzero(res3.output)
ok_3d = res3.success and nz > 0
print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}")
for r in runners.values():
r.cleanup()
passed = res.success and match_2d and ok_3d
print("\n" + "=" * 70)
print(f" BwdWeight 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)")
print(f" BwdWeight 3D: {'PASS' if ok_3d else 'FAIL'}")
print(f" Status: {'PASS' if passed else 'FAIL'}")
print("=" * 70)
return 0 if passed else 1
if __name__ == "__main__":
sys.exit(main())
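`cpu_conv2d_bwd_weight` above reduces over (N, Ho, Wo) for each weight element, which makes it the slowest of the three references in pure Python. The same reduction vectorizes as one `einsum` per filter tap over the padded input windows; the sketch below uses hypothetical names (`cpu_conv2d_bwd_weight_vec`, the stub `prob`) that are not part of the PR.

```python
import numpy as np
from types import SimpleNamespace

def cpu_conv2d_bwd_weight_vec(x, dy, prob):
    """Vectorized dW reference: reduce input windows against dY per (y, x) tap."""
    N, Hi, Wi, G, C = x.shape
    _, Ho, Wo, _, Kpg = dy.shape
    Y, X_ = prob.Y, prob.X
    pads = ((0, 0), (prob.pad_h, prob.pad_h), (prob.pad_w, prob.pad_w), (0, 0), (0, 0))
    padded = np.pad(x.astype(np.float32), pads)
    dy32 = dy.astype(np.float32)
    dw = np.zeros((G, Kpg, Y, X_, C), dtype=np.float32)
    for y in range(Y):
        for xf in range(X_):
            # Input window aligned with each output position for this tap
            win = padded[:, y:y + Ho * prob.stride_h:prob.stride_h,
                            xf:xf + Wo * prob.stride_w:prob.stride_w]
            # Reduce over (n, ho, wo): (N,Ho,Wo,G,C) x (N,Ho,Wo,G,K) -> (G,K,C)
            dw[:, :, y, xf, :] = np.einsum("nhwgc,nhwgk->gkc", win, dy32)
    return dw

# Hypothetical problem stub with the fields the sketch reads.
prob = SimpleNamespace(Y=3, X=3, stride_h=1, stride_w=1, pad_h=1, pad_w=1)
```

Because dW accumulates N·Ho·Wo products per element, its magnitude grows with the problem size, which is presumably why the example uses the looser `atol=0.5`.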


@@ -0,0 +1,318 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 05: Multi-Problem GPU Benchmark
Declares kernels with explicit tile/wave/warp/pipeline parameters for
all directions, builds registries, JIT compiles, and benchmarks across
ResNet-like problem sizes with configurable warmup/repeat.
Usage:
python3 05_benchmark.py
python3 05_benchmark.py --warmup 3 --repeat 10
python3 05_benchmark.py --workers 4
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
detect_gpu_arch,
)
def compute_bytes(prob, dtype_bytes=2):
in_elems = 1
for d in prob.input_shape():
in_elems *= d
wei_elems = 1
for d in prob.weight_shape():
wei_elems *= d
out_elems = 1
for d in prob.output_shape():
out_elems *= d
return (in_elems + wei_elems + out_elems) * dtype_bytes
def main():
parser = argparse.ArgumentParser(description="Multi-Problem GPU Benchmark")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--repeat", type=int, default=5, help="Benchmark iterations")
parser.add_argument(
"--workers", type=int, default=0, help="Max JIT workers (0=auto)"
)
args = parser.parse_args()
print("=" * 70)
print("Example 05: Multi-Problem GPU Benchmark")
print("=" * 70)
print(f"\n Arch: {args.arch}, Dtype: {args.dtype}")
print(f" Warmup: {args.warmup}, Repeat: {args.repeat}")
# =========================================================================
# Step 1: Declare all kernels with explicit parameters
# =========================================================================
print("\n--- Step 1: Declare Kernels ---")
reg = GroupedConvRegistry("benchmark")
# Forward 2D: compv4, 128x128 tile
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv4",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# Forward 3D: compv3, 64x64 tile
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=3,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# BwdData 2D: compv3, 128x128 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_data",
ndim_spatial=2,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
# BwdWeight 2D: compv3, 128x128 tile
reg.add(
GroupedConvKernelConfig(
variant="bwd_weight",
ndim_spatial=2,
arch=args.arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
)
)
reg.print_registry()
# =========================================================================
# Step 2: JIT build
# =========================================================================
print("\n--- Step 2: JIT Build ---")
workers = args.workers if args.workers > 0 else None
t0 = time.perf_counter()
runner_by_key = reg.build(verbose=False, max_workers=workers)
jit_s = time.perf_counter() - t0
for key in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]:
tag = "OK" if key in runner_by_key else "FAILED"
print(f" {key[0]:12s} {key[1]}D: {tag}")
print(f" JIT build time: {jit_s:.3f} s")
missing = [
k
for k in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]
if k not in runner_by_key
]
if missing:
print(f"\n ERROR: missing {missing}")
return 1
    # NumPy has no native bf16 dtype; fp16 is used as the host-side container for both.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
def bench_run(runner, inp, wei, prob):
for _ in range(args.warmup):
runner.run(inp, wei, prob)
times = []
for _ in range(args.repeat):
r = runner.run(inp, wei, prob)
if r.success:
times.append(r.time_ms)
if not times:
return 0.0, 0.0
return min(times), sum(times) / len(times)
# =========================================================================
# Step 3: 2D Forward benchmark
# =========================================================================
print("\n--- Step 3: Forward 2D Benchmark ---")
print(
f"{'Problem':<18} {'N':>3} {'C':>4} {'K':>4} {'H':>3} {'W':>3} "
f"{'F':>3} {'Min(ms)':>9} {'Avg(ms)':>9} {'TFLOPS':>8} {'GB/s':>8}"
)
print("-" * 85)
all_ok = True
for label, n, c, k, h, w, y, x, s, p in [
("ResNet-stage2", 1, 64, 64, 56, 56, 3, 3, 1, 1),
("ResNet-stage3", 1, 128, 128, 28, 28, 3, 3, 1, 1),
("ResNet-stage4", 1, 256, 256, 14, 14, 3, 3, 1, 1),
("ResNet-stage5", 1, 512, 512, 7, 7, 3, 3, 1, 1),
("Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1, 1, 0),
("Batch-8", 8, 64, 128, 56, 56, 3, 3, 1, 1),
("Batch-32", 32, 64, 128, 56, 56, 3, 3, 1, 1),
]:
prob = GroupedConvProblem(
N=n,
C=c,
K=k,
Hi=h,
Wi=w,
Y=y,
X=x,
stride_h=s,
stride_w=s,
pad_h=p,
pad_w=p,
direction="forward",
)
inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
min_ms, avg_ms = bench_run(runner_by_key[("forward", 2)], inp, wei, prob)
if avg_ms > 0:
tflops = prob.flops / (avg_ms * 1e9)
bw = compute_bytes(prob) / (avg_ms * 1e6)
print(
f"{label:<18} {n:>3} {c:>4} {k:>4} {h:>3} {w:>3} "
f"{y}x{x} {min_ms:>9.4f} {avg_ms:>9.4f} {tflops:>8.2f} {bw:>8.1f}"
)
else:
all_ok = False
# =========================================================================
# Step 4: 3D Forward
# =========================================================================
print("\n--- Step 4: Forward 3D ---")
for label, n, c, k, d, h, w, z, y, x in [
("3D-small", 1, 64, 64, 8, 16, 16, 3, 3, 3),
("3D-medium", 1, 64, 128, 16, 32, 32, 3, 3, 3),
]:
prob = GroupedConvProblem(
N=n, C=c, K=k, Di=d, Hi=h, Wi=w, Z=z, Y=y, X=x, direction="forward"
)
inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
min_ms, avg_ms = bench_run(runner_by_key[("forward", 3)], inp, wei, prob)
if avg_ms > 0:
tflops = prob.flops / (avg_ms * 1e9)
print(f" {label:<14} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS")
# =========================================================================
# Step 5: Backward directions
# =========================================================================
print("\n--- Step 5: Backward Directions ---")
for label, direction in [
("bwd_data ResNet-s3", "bwd_data"),
("bwd_weight ResNet-s3", "bwd_weight"),
]:
prob = GroupedConvProblem(
N=1,
C=128,
K=128,
Hi=28,
Wi=28,
Y=3,
X=3,
stride_h=1,
stride_w=1,
pad_h=1,
pad_w=1,
direction=direction,
)
inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
min_ms, avg_ms = bench_run(runner_by_key[(direction, 2)], inp, wei, prob)
if avg_ms > 0:
tflops = prob.flops / (avg_ms * 1e9)
print(
f" {label:<14} {direction:>12} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS"
)
for runner in runner_by_key.values():
runner.cleanup()
print("\n" + "=" * 70)
print(f" JIT build: {jit_s:.3f} s")
print(f" Warmup: {args.warmup}, Repeat: {args.repeat}")
print(f" Status: {'PASS' if all_ok else 'FAIL'}")
print("=" * 70)
return 0 if all_ok else 1
if __name__ == "__main__":
sys.exit(main())
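The TFLOPS and GB/s columns above come from `prob.flops` and `compute_bytes`. Assuming `prob.flops` follows the conventional 2-flops-per-MAC count for grouped convolution (the usual convention, though this file does not show the implementation), the reporting math reduces to the short sketch below; `conv_fwd_flops` and `tflops` are illustrative helper names.

```python
def conv_fwd_flops(N, G, K, C, Ho, Wo, Y, X):
    # 2 flops (multiply + add) per MAC; K and C are per-group channel counts,
    # so the group count G multiplies the total.
    return 2 * N * G * K * C * Ho * Wo * Y * X

def tflops(flops, time_ms):
    # Matches the example's reporting: TFLOPS = FLOPs / (time_ms * 1e9).
    return flops / (time_ms * 1e9)
```

For "ResNet-stage2" (N=1, C=K=64, 56x56 output, 3x3 filter, G=1) this gives about 0.23 GFLOP per forward pass, so sub-millisecond times are expected at the TFLOPS levels the table reports.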


@@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 06: Registry, Heuristic Selection & JSON Export
Declares multiple kernel configurations with different tile sizes,
builds a registry, demonstrates heuristic runtime kernel selection,
JSON round-trip, and GPU execution.
Usage:
python3 06_registry_json.py
python3 06_registry_json.py --workers 4
"""
import sys
import time
import argparse
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
GroupedConvKernelConfig,
GroupedConvProblem,
GroupedConvRegistry,
detect_gpu_arch,
)
def conv_heuristic(problem):
spatial = problem.Ho * problem.Wo
if spatial > 400:
return ["256", "128", "64"]
return ["64", "128", "256"]
def main():
parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON")
parser.add_argument("--arch", default=detect_gpu_arch())
parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
parser.add_argument("--workers", type=int, default=0)
args = parser.parse_args()
arch = args.arch
print("=" * 70)
print("Example 06: Registry, Heuristic Selection & JSON Export")
print("=" * 70)
print(f"\n Arch: {arch}, Dtype: {args.dtype}")
# Step 1: Declare kernels with full explicit parameters
print("\n--- Step 1: Declare Kernels + Build Registry ---")
reg = GroupedConvRegistry("conv_tiles")
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=256,
tile_k=256,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
num_wave_groups=1,
num_groups_to_merge=1,
)
)
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv4",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
num_wave_groups=1,
num_groups_to_merge=1,
)
)
reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=64,
tile_k=64,
wave_m=1,
wave_n=4,
wave_k=1,
warp_tile_m=16,
warp_tile_n=16,
warp_tile_k=32,
pipeline="compv3",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
block_per_cu=1,
num_wave_groups=1,
num_groups_to_merge=1,
)
)
reg.print_registry()
# Step 2: Heuristic kernel selection
print("\n--- Step 2: Heuristic Kernel Selection ---")
problems = [
(
"small_7x7",
GroupedConvProblem(
N=1,
C=512,
K=512,
Hi=7,
Wi=7,
Y=3,
X=3,
pad_h=1,
pad_w=1,
direction="forward",
),
),
(
"medium_14x14",
GroupedConvProblem(
N=1,
C=256,
K=256,
Hi=14,
Wi=14,
Y=3,
X=3,
pad_h=1,
pad_w=1,
direction="forward",
),
),
(
"large_56x56",
GroupedConvProblem(
N=1,
C=64,
K=128,
Hi=56,
Wi=56,
Y=3,
X=3,
pad_h=1,
pad_w=1,
direction="forward",
),
),
]
print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}")
print(f" {'-' * 74}")
for label, prob in problems:
selected = reg.select(prob, heuristic=conv_heuristic)
spatial = prob.Ho * prob.Wo
sel_name = selected.name if selected else "none"
print(f" {label:<16} {spatial:>8} {sel_name:<50}")
# Step 3: JSON round-trip
print("\n--- Step 3: JSON Round-Trip ---")
json_str = reg.to_json()
print(f" Exported: {len(json_str)} bytes, {len(reg)} kernels")
imported = GroupedConvRegistry.from_json(json_str)
print(f" Imported: {len(imported)} kernels")
orig = reg.kernels[0]
imp = imported.kernels[0]
rt_ok = (
orig.vector_size_a == imp.vector_size_a
and orig.block_per_cu == imp.block_per_cu
and orig.tile_n == imp.tile_n
)
print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}")
# Step 4: JIT build + GPU execution
print("\n--- Step 4: JIT Build + GPU Execution ---")
workers = args.workers if args.workers > 0 else None
jit_reg = GroupedConvRegistry("jit_conv")
jit_reg.add(
GroupedConvKernelConfig(
variant="forward",
ndim_spatial=2,
arch=arch,
dtype=args.dtype,
tile_m=1,
tile_n=128,
tile_k=128,
wave_m=2,
wave_n=2,
wave_k=1,
warp_tile_m=32,
warp_tile_n=32,
warp_tile_k=16,
pipeline="compv4",
scheduler="intrawave",
epilogue="cshuffle",
vector_size_a=4,
vector_size_b=8,
vector_size_c=8,
)
)
t0 = time.perf_counter()
runners = jit_reg.build(verbose=False, max_workers=workers)
jit_s = time.perf_counter() - t0
if ("forward", 2) not in runners:
print(" JIT build failed")
return 1
runner = runners[("forward", 2)]
print(f" JIT build: {jit_s:.3f} s")
print(f" Library: {runner.library_path}")
prob = GroupedConvProblem(
N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward"
)
    # NumPy has no native bf16 dtype; fp16 is used as the host-side container for both.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype)
wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype)
res = runner.run(inp, wei, prob)
runner.cleanup()
if res.success:
print(f" Time: {res.time_ms:.4f} ms")
print(f" TFLOPS: {res.tflops:.2f}")
print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}")
gpu_ok = res.success
print("\n" + "=" * 70)
print(f" Registry: {len(reg)} kernels (3 tile configs)")
print(" Heuristic: spatial-based selection demonstrated")
print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}")
print(f" GPU: {'OK' if gpu_ok else 'FAIL'}")
print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}")
print("=" * 70)
return 0 if gpu_ok and rt_ok else 1
if __name__ == "__main__":
sys.exit(main())
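`conv_heuristic` above returns tile sizes as strings in priority order, and `reg.select(prob, heuristic=...)` picks a registered kernel from that preference list. The file does not show `select`'s internals; a plausible sketch of preference-list matching, under the assumption that the heuristic strings are compared against each config's `tile_n`, is below (`select_by_preference` and `spatial_heuristic` are hypothetical names, not the registry's actual implementation).

```python
from types import SimpleNamespace

def select_by_preference(kernels, problem, heuristic):
    # Walk the heuristic's preference list in order; return the first
    # registered kernel whose tile_n matches, else fall back to the first.
    for pref in heuristic(problem):
        for k in kernels:
            if str(k.tile_n) == pref:
                return k
    return kernels[0] if kernels else None

def spatial_heuristic(problem):
    # Mirrors conv_heuristic above: prefer big tiles for large spatial extents.
    return ["256", "128", "64"] if problem.Ho * problem.Wo > 400 else ["64", "128", "256"]
```

With the three tile configs declared in this example (256, 128, 64), `large_56x56` would resolve to the 256 tile and `small_7x7` to the 64 tile under this scheme.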


@@ -3,9 +3,17 @@
#pragma once
/// Main dispatcher header - includes all core components
/// Use this for convenient access to the full dispatcher API
/// Full dispatcher header - includes ALL operation types.
/// For minimal includes, use the per-operation headers instead:
/// ck_tile/dispatcher_gemm.hpp -- GEMM only
/// ck_tile/dispatcher_conv.hpp -- Grouped Convolution only
// Core (needed by all ops)
#include "ck_tile/dispatcher/base_registry.hpp"
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
// GEMM
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/kernel_config.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
@@ -13,7 +21,15 @@
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/json_export.hpp"
#include "ck_tile/dispatcher/arch_filter.hpp"
#include "ck_tile/dispatcher/backends/tile_backend.hpp"
#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp"
#include "ck_tile/dispatcher/utils.hpp"
// Grouped Convolution
#include "ck_tile/dispatcher/grouped_conv_config.hpp"
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"


@@ -1,6 +1,6 @@
# CK Tile Dispatcher - C++ Headers
C++ API for the CK Tile dispatcher.
C++ API for the CK Tile dispatcher (GEMM and Grouped Convolution).
> **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts.
@@ -8,16 +8,25 @@ C++ API for the CK Tile dispatcher.
```
dispatcher/
├── dispatcher.hpp # Main dispatcher (kernel selection)
├── registry.hpp # Kernel registry (storage & lookup)
├── problem.hpp # Problem specification
├── kernel_key.hpp # Kernel configuration key
├── kernel_instance.hpp # Kernel instance interface
├── utils.hpp # Utilities (timers, GPU buffers)
└── backends/ # Backend implementations
├── generated_tile_backend.hpp # CK Tile kernels (production)
└── tile_backend.hpp # Tile backend base
├── dispatcher.hpp                   # Main include (includes all below)
│
│   # GEMM Headers
├── registry.hpp                     # Kernel registry (storage & lookup)
├── problem.hpp                      # GEMM problem specification
├── kernel_key.hpp                   # Kernel configuration key
├── kernel_instance.hpp              # Kernel instance interface
├── utils.hpp                        # Utilities (timers, GPU buffers)
│
│   # Grouped Convolution Headers
├── grouped_conv_config.hpp          # GroupedConvDirection, GroupedConvConfig
├── grouped_conv_problem.hpp         # GroupedConvProblem + ProblemBuilder
├── grouped_conv_kernel_decl.hpp     # GroupedConvKernelDecl, DECL_GROUPED_CONV_KERNEL_SET
├── grouped_conv_registry.hpp        # Thread-safe registry with JSON export & filtering
├── grouped_conv_utils.hpp           # Config creators, validation, benchmark utilities
│
└── backends/                        # Backend implementations
    ├── generated_tile_backend.hpp   # CK Tile kernels (production)
    └── tile_backend.hpp             # Tile backend base
```
## Quick Start
@@ -148,6 +157,69 @@ auto kernel = create_generated_tile_kernel<
>(key, name);
```
## Grouped Convolution API
### GroupedConvProblem (`grouped_conv_problem.hpp`)
Problem specification with builder pattern:
```cpp
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
using namespace ck_tile::dispatcher;
auto problem = GroupedConvProblemBuilder()
.n(2).g(1).c(128).k(256)
.input_spatial({28, 28})
.filter_spatial({3, 3})
.strides({1, 1})
.dilations({1, 1})
.left_pads({1, 1})
.right_pads({1, 1})
.build();
bool ok = problem.is_valid();
```
### GroupedConvRegistry (`grouped_conv_registry.hpp`)
Thread-safe registry with JSON export and filtering:
```cpp
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
auto& registry = GroupedConvRegistry::instance();
// Thread-safe registration
registry.register_kernel(kernel);
// JSON export
std::string json = registry.export_json();
registry.export_json_to_file("kernels.json");
// Filtering
auto gfx942_kernels = registry.filter_by_arch("gfx942");
auto matched = registry.filter([](const auto& k) { return k.is_fwd(); });
```
### DECL_GROUPED_CONV_KERNEL_SET (`grouped_conv_kernel_decl.hpp`)
Declarative kernel definition:
```cpp
DECL_GROUPED_CONV_KERNEL_SET(my_conv_kernels,
.add(
GroupedConvSignature().dtype("fp16").layout("nhwgc"),
GroupedConvAlgorithm().tile(128, 128, 32).wave(2, 2, 1)
.warp(32, 32, 16).pipeline("compv4"),
"gfx942"
)
);
// Register all kernels matching the current architecture
DECL_GROUPED_CONV_KERNEL_ALL(all_conv_kernels, "gfx942");
```
## Best Practices
1. Use `Release` build for performance
@@ -155,6 +227,8 @@ auto kernel = create_generated_tile_kernel<
3. Use `Priority::High` for hand-tuned kernels
4. Reuse dispatcher instances
5. Clear registry between test runs
6. Use `GroupedConvProblemBuilder` for validated problem construction
7. Leverage `export_json()` for kernel inventory and debugging
---

View File

@@ -0,0 +1,152 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Generated Convolution Kernel Backend
//
// Wraps CK Tile grouped convolution launchers for use through the
// GroupedConvDispatcher. Each generated kernel launcher is wrapped in
// a ConvKernelRunFn that builds the correct host-args type (forward,
// bwd-data, or bwd-weight) and calls Launcher::launch().
#pragma once
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include <hip/hip_runtime.h>
#include <functional>
namespace ck_tile {
namespace dispatcher {
namespace backends {
// The buffer context (g_conv_dispatch_buffers) is defined in
// grouped_conv_registry.hpp, which avoids a circular dependency.
// Helper: build ck_tile::conv::ConvParam from GroupedConvProblem
inline ck_tile::conv::ConvParam make_conv_param_2d(const GroupedConvProblem& p)
{
return ck_tile::conv::ConvParam{
2,
static_cast<ck_tile::index_t>(p.G),
static_cast<ck_tile::index_t>(p.N),
static_cast<ck_tile::index_t>(p.K),
static_cast<ck_tile::index_t>(p.C),
{static_cast<ck_tile::index_t>(p.filter_spatial[1]),
static_cast<ck_tile::index_t>(p.filter_spatial[2])},
{static_cast<ck_tile::index_t>(p.input_spatial[1]),
static_cast<ck_tile::index_t>(p.input_spatial[2])},
{static_cast<ck_tile::index_t>(p.stride[1]), static_cast<ck_tile::index_t>(p.stride[2])},
{static_cast<ck_tile::index_t>(p.dilation[1]),
static_cast<ck_tile::index_t>(p.dilation[2])},
{static_cast<ck_tile::index_t>(p.padding[1]), static_cast<ck_tile::index_t>(p.padding[2])},
{static_cast<ck_tile::index_t>(p.padding[1]), static_cast<ck_tile::index_t>(p.padding[2])}};
}
inline ck_tile::conv::ConvParam make_conv_param_3d(const GroupedConvProblem& p)
{
return ck_tile::conv::ConvParam{3,
static_cast<ck_tile::index_t>(p.G),
static_cast<ck_tile::index_t>(p.N),
static_cast<ck_tile::index_t>(p.K),
static_cast<ck_tile::index_t>(p.C),
{static_cast<ck_tile::index_t>(p.filter_spatial[0]),
static_cast<ck_tile::index_t>(p.filter_spatial[1]),
static_cast<ck_tile::index_t>(p.filter_spatial[2])},
{static_cast<ck_tile::index_t>(p.input_spatial[0]),
static_cast<ck_tile::index_t>(p.input_spatial[1]),
static_cast<ck_tile::index_t>(p.input_spatial[2])},
{static_cast<ck_tile::index_t>(p.stride[0]),
static_cast<ck_tile::index_t>(p.stride[1]),
static_cast<ck_tile::index_t>(p.stride[2])},
{static_cast<ck_tile::index_t>(p.dilation[0]),
static_cast<ck_tile::index_t>(p.dilation[1]),
static_cast<ck_tile::index_t>(p.dilation[2])},
{static_cast<ck_tile::index_t>(p.padding[0]),
static_cast<ck_tile::index_t>(p.padding[1]),
static_cast<ck_tile::index_t>(p.padding[2])},
{static_cast<ck_tile::index_t>(p.padding[0]),
static_cast<ck_tile::index_t>(p.padding[1]),
static_cast<ck_tile::index_t>(p.padding[2])}};
}
// Create a RunFn for a forward convolution launcher (2D or 3D)
template <typename LauncherType, int NDim>
inline GroupedConvKernelInstance::RunFn make_conv_fwd_run_fn()
{
return [](const GroupedConvProblem& problem, void* stream) -> float {
auto& ctx = g_conv_dispatch_buffers;
auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem);
ck_tile::GroupedConvFwdHostArgs<> args(
param, ctx.input_ptr, ctx.weight_ptr, {}, ctx.output_ptr, 1);
ck_tile::stream_config sc;
sc.stream_id_ = reinterpret_cast<hipStream_t>(stream);
sc.time_kernel_ = ctx.benchmarking;
sc.log_level_ = 0;
sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0;
sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1;
sc.is_gpu_timer_ = ctx.benchmarking;
return LauncherType::launch(args, sc);
};
}
// Create a RunFn for a backward-data convolution launcher.
// Dispatcher convention: run(dY, W, dX, problem) where dX is computed.
// BwdDataHostArgs(param, in_ptr=dX, wei_ptr=W, {}, out_ptr=dY, k_batch)
template <typename LauncherType, int NDim>
inline GroupedConvKernelInstance::RunFn make_conv_bwd_data_run_fn()
{
return [](const GroupedConvProblem& problem, void* stream) -> float {
auto& ctx = g_conv_dispatch_buffers;
auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem);
ck_tile::GroupedConvBwdDataHostArgs args(
param,
ctx.output_ptr, // in_ptr = dX (being computed)
ctx.weight_ptr, // wei_ptr = W
{},
ctx.input_ptr, // out_ptr = dY (gradient from next layer)
1);
ck_tile::stream_config sc;
sc.stream_id_ = reinterpret_cast<hipStream_t>(stream);
sc.time_kernel_ = ctx.benchmarking;
sc.log_level_ = 0;
sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0;
sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1;
sc.is_gpu_timer_ = ctx.benchmarking;
return LauncherType::launch(args, sc);
};
}
// Create a RunFn for a backward-weight convolution launcher.
// Dispatcher convention: run(X, dY, dW, problem) where dW is computed.
// BwdWeightHostArgs(param, in_ptr=X, wei_ptr=dW, {}, out_ptr=dY, k_batch)
template <typename LauncherType, int NDim>
inline GroupedConvKernelInstance::RunFn make_conv_bwd_weight_run_fn()
{
return [](const GroupedConvProblem& problem, void* stream) -> float {
auto& ctx = g_conv_dispatch_buffers;
auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem);
const int k_batch = (ctx.split_k > 1) ? ctx.split_k : 1;
ck_tile::GroupedConvBwdWeightHostArgs args(param,
ctx.input_ptr, // in_ptr = X
ctx.output_ptr, // wei_ptr = dW (being computed)
{},
ctx.weight_ptr, // out_ptr = dY
k_batch);
ck_tile::stream_config sc;
sc.stream_id_ = reinterpret_cast<hipStream_t>(stream);
sc.time_kernel_ = ctx.benchmarking;
sc.log_level_ = 0;
sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0;
sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1;
sc.is_gpu_timer_ = ctx.benchmarking;
return LauncherType::launch(args, sc);
};
}
} // namespace backends
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -0,0 +1,199 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
namespace ck_tile {
namespace dispatcher {
/// Shared priority enum used by all registry types
enum class Priority
{
Low = 0,
Normal = 1,
High = 2
};
/// BaseRegistry: Thread-safe, priority-aware kernel storage shared by GEMM and Conv registries.
///
/// Template Parameters:
/// Derived - CRTP derived class (e.g., Registry, ConvRegistry)
/// KeyType - primary key type (std::string for GEMM, ConvKernelKey for Conv)
/// InstanceType - kernel instance type (KernelInstance, ConvKernelInstance)
/// KeyHash - hash functor for KeyType (defaults to std::hash<KeyType>)
template <typename Derived,
typename KeyType,
typename InstanceType,
typename KeyHash = std::hash<KeyType>>
class BaseRegistry
{
public:
using InstancePtr = std::shared_ptr<InstanceType>;
struct Entry
{
InstancePtr instance;
Priority priority;
};
BaseRegistry() = default;
virtual ~BaseRegistry() = default;
BaseRegistry(BaseRegistry&& other) noexcept
{
std::lock_guard<std::mutex> lock(other.mutex_);
entries_ = std::move(other.entries_);
name_ = std::move(other.name_);
}
BaseRegistry& operator=(BaseRegistry&& other) noexcept
{
if(this != &other)
{
std::scoped_lock lock(mutex_, other.mutex_);
entries_ = std::move(other.entries_);
name_ = std::move(other.name_);
}
return *this;
}
BaseRegistry(const BaseRegistry&) = delete;
BaseRegistry& operator=(const BaseRegistry&) = delete;
/// Register a kernel. If the key already exists, the new entry replaces it
/// unless the existing entry has strictly higher priority.
/// Same-priority registration overwrites (last-writer-wins at equal priority).
bool
register_kernel(const KeyType& key, InstancePtr instance, Priority priority = Priority::Normal)
{
std::lock_guard<std::mutex> lock(mutex_);
auto it = entries_.find(key);
if(it != entries_.end() && it->second.priority > priority)
{
return false;
}
entries_[key] = Entry{std::move(instance), priority};
return true;
}
[[nodiscard]] std::size_t size() const
{
std::lock_guard<std::mutex> lock(mutex_);
return entries_.size();
}
[[nodiscard]] bool empty() const
{
std::lock_guard<std::mutex> lock(mutex_);
return entries_.empty();
}
void clear()
{
std::lock_guard<std::mutex> lock(mutex_);
entries_.clear();
}
[[nodiscard]] std::string get_name() const
{
std::lock_guard<std::mutex> lock(mutex_);
return name_; // return by value to avoid dangling reference
}
void set_name(const std::string& name)
{
std::lock_guard<std::mutex> lock(mutex_);
name_ = name;
}
[[nodiscard]] std::vector<InstancePtr> get_all_instances() const
{
std::lock_guard<std::mutex> lock(mutex_);
std::vector<InstancePtr> result;
result.reserve(entries_.size());
for(const auto& [key, entry] : entries_)
{
result.push_back(entry.instance);
}
return result;
}
std::size_t merge_from(const BaseRegistry& other, Priority priority = Priority::Normal)
{
std::scoped_lock lock(mutex_, other.mutex_);
std::size_t merged = 0;
for(const auto& [key, entry] : other.entries_)
{
auto it = entries_.find(key);
if(it == entries_.end() || it->second.priority <= priority)
{
entries_[key] = Entry{entry.instance, priority};
++merged;
}
}
return merged;
}
/// Enable automatic JSON export after every kernel registration.
/// Requires the derived class to implement export_json_to_file(path, stats).
void enable_auto_export(const std::string& path,
bool include_statistics = true,
bool export_on_every_registration = true)
{
std::lock_guard<std::mutex> lock(mutex_);
auto_export_path_ = path;
auto_export_stats_ = include_statistics;
auto_export_on_register_ = export_on_every_registration;
auto_export_enabled_.store(true, std::memory_order_release);
}
void disable_auto_export() { auto_export_enabled_.store(false, std::memory_order_release); }
[[nodiscard]] bool is_auto_export_enabled() const
{
return auto_export_enabled_.load(std::memory_order_acquire);
}
/// Call after registration to trigger auto-export if enabled.
void perform_auto_export()
{
if(!auto_export_enabled_.load(std::memory_order_acquire))
return;
std::lock_guard<std::mutex> lock(mutex_);
if(auto_export_on_register_)
{
static_cast<Derived*>(this)->export_json_to_file(auto_export_path_, auto_export_stats_);
}
}
protected:
[[nodiscard]] const std::unordered_map<KeyType, Entry, KeyHash>& entries() const
{
return entries_;
}
[[nodiscard]] std::unordered_map<KeyType, Entry, KeyHash>& entries_mut() { return entries_; }
std::mutex& mutex() const { return mutex_; }
private:
mutable std::mutex mutex_;
std::unordered_map<KeyType, Entry, KeyHash> entries_;
std::string name_ = "default";
std::atomic<bool> auto_export_enabled_{false};
bool auto_export_on_register_ = true;
bool auto_export_stats_ = true;
std::string auto_export_path_;
};
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -23,6 +23,7 @@
#pragma once
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/problem.hpp"
#include "ck_tile/dispatcher/registry.hpp"
@@ -52,7 +53,11 @@ class Dispatcher
/// Constructor
/// @param registry Registry instance to use (default: global singleton)
explicit Dispatcher(Registry* registry = nullptr);
/// @param gfx_arch Target GPU architecture (e.g. "gfx950")
explicit Dispatcher(Registry* registry = nullptr, const std::string& gfx_arch = "");
void set_arch(const std::string& arch) { gfx_arch_ = arch; }
[[nodiscard]] const std::string& arch() const { return gfx_arch_; }
/// Register a heuristic function for kernel selection
/// @param heuristic Function that maps problems to ranked kernel identifiers
@@ -74,7 +79,7 @@ class Dispatcher
/// @param problem Problem configuration
/// @param stream HIP stream for kernel launch (nullptr = default stream)
/// @return Kernel execution time in milliseconds
/// @throws std::runtime_error if no suitable kernel found
/// @throws NoKernelFound if no suitable kernel found
[[nodiscard]] float run(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
@@ -89,7 +94,7 @@ class Dispatcher
/// @param problem Problem configuration
/// @param stream HIP stream for kernel launch (nullptr = default stream)
/// @return Kernel execution time in milliseconds
/// @throws std::runtime_error if no suitable kernel found
/// @throws NoKernelFound if no suitable kernel found
[[nodiscard]] float run_fused(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
@@ -106,7 +111,8 @@ class Dispatcher
/// @param problem Problem configuration
/// @param stream HIP stream for kernel launch (nullptr = default stream)
/// @return Kernel execution time in milliseconds
/// @throws std::runtime_error if kernel not found or doesn't support problem
/// @throws NoKernelFound if the kernel identifier is not registered
/// @throws UnsupportedProblem if the selected kernel does not support the problem
[[nodiscard]] float run_explicit(const std::string& kernel_id,
const void* a_ptr,
const void* b_ptr,
@@ -130,10 +136,18 @@ class Dispatcher
const Problem& problem,
float tolerance = 1e-3f) const;
/// Enable or disable GPU benchmarking (timing) on all kernels.
/// When disabled, kernels execute once with no timing overhead
/// (one-shot mode for production plugins).
void set_benchmarking(bool enable) { benchmarking_ = enable; }
[[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; }
private:
Registry* registry_;
HeuristicFunction heuristic_;
SelectionStrategy strategy_;
std::string gfx_arch_;
bool benchmarking_ = true;
/// Select kernel using first-fit strategy
[[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const;

View File

@@ -0,0 +1,28 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <stdexcept>
#include <string>
namespace ck_tile {
namespace dispatcher {
struct DispatcherError : std::runtime_error
{
using std::runtime_error::runtime_error;
};
struct NoKernelFound : DispatcherError
{
using DispatcherError::DispatcherError;
};
struct UnsupportedProblem : DispatcherError
{
using DispatcherError::DispatcherError;
};
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <iostream>
#include <string>
namespace ck_tile {
namespace dispatcher {
/// Log levels for dispatcher transparency:
/// 0 = silent (default)
/// 1 = print selected kernel name
/// 2 = print all candidates considered and acceptance/rejection reasons
inline int get_log_level()
{
static int level = []() {
const char* env = std::getenv("CK_DISPATCHER_LOG_LEVEL");
return env ? std::atoi(env) : 0;
}();
return level;
}
inline void log_kernel_selected(const std::string& kernel_name, const std::string& problem_desc)
{
if(get_log_level() >= 1)
{
std::cerr << "[CK Dispatcher] Selected kernel: " << kernel_name << " for " << problem_desc
<< std::endl;
}
}
inline void
log_kernel_candidate(const std::string& kernel_name, bool accepted, const std::string& reason)
{
if(get_log_level() >= 2)
{
std::cerr << "[CK Dispatcher] Candidate: " << kernel_name << " -> "
<< (accepted ? "ACCEPTED" : "REJECTED")
<< (reason.empty() ? "" : " (" + reason + ")") << std::endl;
}
}
inline void log_no_kernel_found(const std::string& problem_desc)
{
if(get_log_level() >= 1)
{
std::cerr << "[CK Dispatcher] No kernel found for " << problem_desc << std::endl;
}
}
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -0,0 +1,588 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file grouped_conv_config.hpp
* @brief CK Tile Grouped Convolution Configuration with Builder-style naming
*
* This adopts the Signature/Algorithm/Arch pattern from:
* experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp
*
* Structure:
* - Signature: WHAT operation (types, layouts, direction, element ops)
* - Algorithm: HOW it's computed (tiles, warps, pipeline, scheduler, padding)
* - Arch: Target GPU architecture
*/
#pragma once
// Use common kernel_key types for DataType, Pipeline, etc.
#include "ck_tile/dispatcher/kernel_key.hpp"
#include <string>
#include <sstream>
#include <array>
#include <cstdint>
namespace ck_tile {
namespace dispatcher {
// DataType, Pipeline, Scheduler, Epilogue are defined in kernel_key.hpp and
// are not redefined here; this header adds conv-specific variants
// (ConvDataType, PipelineVersion, PipelineScheduler, EpilogueType) alongside them.
// =============================================================================
// Data Type Enum (matching CK Tile numeric types)
// =============================================================================
enum class ConvDataType
{
// Standard floating point
FP32, // float
FP64, // double
FP16, // half_t
BF16, // bf16_t
// 8-bit float variants (FP8/BF8)
FP8, // fp8_t (E4M3)
BF8, // bf8_t (E5M2)
FP8_E4M3, // Explicit E4M3 format
FP8_E5M2, // Explicit E5M2 format
// Integer types
INT8, // int8_t
UINT8, // uint8_t
INT32, // int32_t (accumulator)
// 4-bit types (gfx950+ only)
FP4, // MXFP4
INT4 // pk_int4_t
};
// =============================================================================
// Direction and Layout Enums
// =============================================================================
enum class GroupedConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
enum class ConvLayout2D
{
GNHWC_GKYXC_GNHWK, // NHWC-style
NHWGC_GKYXC_NHWGK,
NGCHW_GKYXC_NGKHW, // NCHW-style
NGCHW_GKCYX_NGKHW
};
enum class ConvLayout3D
{
GNDHWC_GKZYXC_GNDHWK,
NDHWGC_GKZYXC_NDHWGK,
NGCDHW_GKZYXC_NGKDHW,
NGCDHW_GKCZYX_NGKDHW
};
// =============================================================================
// Element-wise Operations
// =============================================================================
enum class ElementwiseOp
{
PASS_THROUGH,
BIAS,
BIAS_CLAMP,
SCALE,
BILINEAR,
RELU,
GELU,
SIGMOID,
TANH
};
// =============================================================================
// Grouped Convolution Specialization
// =============================================================================
enum class ConvSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3X3,
FILTER_5X5,
FILTER_7X7
};
// =============================================================================
// Memory Operation Types (for accumulator operations)
// =============================================================================
enum class MemoryOperation
{
SET, // Direct write (=)
ATOMIC_ADD, // Atomic addition (+=)
ATOMIC_MAX, // Atomic max
ADD // Non-atomic addition
};
// =============================================================================
// Epilogue Types
// =============================================================================
enum class EpilogueType
{
CSHUFFLE, // C-shuffle epilogue
DEFAULT_2D, // Default 2D epilogue
DEFAULT_GEMM_2D, // Default GEMM 2D epilogue
DIRECT_STORE, // Direct store without shuffle
BIAS_ADD, // Add bias
BIAS_ADD_RELU, // Add bias + ReLU
BIAS_ADD_GELU // Add bias + GELU
};
// =============================================================================
// Algorithm Enums (matching builder/types.hpp and CK Tile pipelines)
// =============================================================================
enum class PipelineVersion
{
V1, // Basic pipeline V1
V2, // Basic pipeline V2
V3, // Compute V3 (intrawave only)
V4, // Compute V4 (double buffer, ping-pong LDS)
V5, // Compute V5 (wave groups)
V6, // Compute V6 (newest)
MEMORY, // Memory pipeline
COMPUTE_ASYNC, // Compute with async copy
PRESHUFFLE_V2 // Preshuffle V2 pipeline
};
enum class PipelineScheduler
{
DEFAULT,
INTRAWAVE,
INTERWAVE
};
enum class GemmPadding
{
DEFAULT,
NO_PADDING, // No padding
M_PADDING,
N_PADDING,
K_PADDING,
MN_PADDING,
MK_PADDING,
NK_PADDING,
MNK_PADDING
};
// =============================================================================
// Signature Info (WHAT operation)
// =============================================================================
struct GroupedConvSignatureInfo
{
int spatial_dim = 2; // 1, 2, or 3
GroupedConvDirection direction = GroupedConvDirection::FORWARD;
std::string in_type = "fp16";
std::string wei_type = "fp16";
std::string out_type = "fp16";
std::string acc_type = "fp32";
std::string workspace_type = "fp32"; // For two-stage algorithms
std::string bias_type = "fp16"; // For bias epilogue
ElementwiseOp in_element_op = ElementwiseOp::PASS_THROUGH;
ElementwiseOp wei_element_op = ElementwiseOp::PASS_THROUGH;
ElementwiseOp out_element_op = ElementwiseOp::PASS_THROUGH;
ConvSpecialization conv_spec = ConvSpecialization::DEFAULT;
int num_groups = 1;
// String helpers
static const char* direction_str(GroupedConvDirection dir)
{
switch(dir)
{
case GroupedConvDirection::FORWARD: return "fwd";
case GroupedConvDirection::BACKWARD_DATA: return "bwd_data";
case GroupedConvDirection::BACKWARD_WEIGHT: return "bwd_weight";
default: return "unknown";
}
}
static const char* datatype_str(ConvDataType dt)
{
switch(dt)
{
case ConvDataType::FP32: return "fp32";
case ConvDataType::FP64: return "fp64";
case ConvDataType::FP16: return "fp16";
case ConvDataType::BF16: return "bf16";
case ConvDataType::FP8: return "fp8";
case ConvDataType::BF8: return "bf8";
case ConvDataType::FP8_E4M3: return "fp8_e4m3";
case ConvDataType::FP8_E5M2: return "fp8_e5m2";
case ConvDataType::INT8: return "int8";
case ConvDataType::UINT8: return "uint8";
case ConvDataType::INT32: return "int32";
case ConvDataType::FP4: return "fp4";
case ConvDataType::INT4: return "int4";
default: return "unknown";
}
}
};
// =============================================================================
// Algorithm Info (HOW it's computed)
// =============================================================================
struct DataTileInfo
{
int m = 128; // M tile (output spatial * N)
int n = 128; // N tile (K output channels)
int k = 64; // K tile (C input channels)
};
struct WarpGemmParams
{
int gemm_m = 16; // MFMA M dimension (MPerXDL)
int gemm_n = 16; // MFMA N dimension (NPerXDL)
int m_iter = 2; // M iterations per warp (MXdlPerWave)
int n_iter = 2; // N iterations per warp (NXdlPerWave)
};
struct BlockWarpConfig
{
int m_warp = 2; // Warps along M
int n_warp = 2; // Warps along N
int k_warp = 1; // Warps along K
int m_warp_tile = 32; // Warp tile M
int n_warp_tile = 32; // Warp tile N
int k_warp_tile = 16; // Warp tile K
};
struct VectorSizeInfo
{
int a = 4; // Input vector size
int b = 8; // Weight vector size
int c = 8; // Output vector size
};
struct GroupedConvAlgorithmInfo
{
DataTileInfo tile;
BlockWarpConfig warp;
VectorSizeInfo vector_size;
PipelineVersion pipeline = PipelineVersion::V4;
PipelineScheduler scheduler = PipelineScheduler::INTRAWAVE;
GemmPadding padding = GemmPadding::MNK_PADDING;
MemoryOperation memory_op = MemoryOperation::SET;
EpilogueType epilogue = EpilogueType::CSHUFFLE;
int thread_block_size = 256;
bool double_smem_buffer = false;
int num_wave_groups = 1;
int block_per_cu = 1;
int num_groups_to_merge = 1;
// Pipeline string
static const char* pipeline_str(PipelineVersion pv)
{
switch(pv)
{
case PipelineVersion::V1: return "v1";
case PipelineVersion::V2: return "v2";
case PipelineVersion::V3: return "compv3";
case PipelineVersion::V4: return "compv4";
case PipelineVersion::V5: return "compv5";
case PipelineVersion::V6: return "compv6";
case PipelineVersion::MEMORY: return "mem";
case PipelineVersion::COMPUTE_ASYNC: return "comp_async";
case PipelineVersion::PRESHUFFLE_V2: return "preshuffle_v2";
default: return "unknown";
}
}
static const char* scheduler_str(PipelineScheduler ps)
{
switch(ps)
{
case PipelineScheduler::DEFAULT: return "default";
case PipelineScheduler::INTRAWAVE: return "intrawave";
case PipelineScheduler::INTERWAVE: return "interwave";
default: return "unknown";
}
}
static const char* memory_op_str(MemoryOperation mo)
{
switch(mo)
{
case MemoryOperation::SET: return "set";
case MemoryOperation::ATOMIC_ADD: return "atomic_add";
case MemoryOperation::ATOMIC_MAX: return "atomic_max";
case MemoryOperation::ADD: return "add";
default: return "unknown";
}
}
static const char* epilogue_str(EpilogueType et)
{
switch(et)
{
case EpilogueType::CSHUFFLE: return "cshuffle";
case EpilogueType::DEFAULT_2D: return "default_2d";
case EpilogueType::DEFAULT_GEMM_2D: return "default_gemm_2d";
case EpilogueType::DIRECT_STORE: return "direct_store";
case EpilogueType::BIAS_ADD: return "bias_add";
case EpilogueType::BIAS_ADD_RELU: return "bias_add_relu";
case EpilogueType::BIAS_ADD_GELU: return "bias_add_gelu";
default: return "unknown";
}
}
};
// =============================================================================
// Arch Info (Target GPU)
// =============================================================================
struct ArchInfo
{
std::string name = "gfx942"; // MI300X default
int max_waves_per_cu = 8;
int lds_size_kb = 64;
int sgpr_count = 108;
int vgpr_count = 512;
bool supports_mfma_fp16() const { return name.find("gfx9") != std::string::npos; }
bool supports_wmma() const { return name.find("gfx11") != std::string::npos; }
};
// =============================================================================
// Full Grouped Conv Config (combines Signature + Algorithm + Arch)
// =============================================================================
struct GroupedConvConfig
{
GroupedConvSignatureInfo signature;
GroupedConvAlgorithmInfo algorithm;
ArchInfo arch;
// Generate unique kernel name
std::string name() const
{
std::ostringstream oss;
oss << "grouped_conv_" << GroupedConvSignatureInfo::direction_str(signature.direction)
<< "_" << signature.in_type << "_" << signature.spatial_dim << "d" << "_"
<< GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) << "_" << algorithm.tile.m
<< "x" << algorithm.tile.n << "x" << algorithm.tile.k;
return oss.str();
}
// Brief description
std::string brief() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D "
<< GroupedConvSignatureInfo::direction_str(signature.direction)
<< " Grouped Convolution (" << signature.in_type << ")";
return oss.str();
}
// Detailed description (tree-like)
std::string detailed() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D "
<< GroupedConvSignatureInfo::direction_str(signature.direction)
<< " Grouped Convolution Kernel\n";
oss << " Signature:\n";
oss << " Data Type: " << signature.in_type << "\n";
oss << " Accumulator: " << signature.acc_type << "\n";
oss << " Groups: " << signature.num_groups << "\n";
oss << " Algorithm:\n";
oss << " Thread Block Size: " << algorithm.thread_block_size << "\n";
oss << " Data Tile: " << algorithm.tile.m << "x" << algorithm.tile.n << "x"
<< algorithm.tile.k << "\n";
oss << " Warp Config: " << algorithm.warp.m_warp << "x" << algorithm.warp.n_warp << "x"
<< algorithm.warp.k_warp << "\n";
oss << " Warp Tile: " << algorithm.warp.m_warp_tile << "x" << algorithm.warp.n_warp_tile
<< "x" << algorithm.warp.k_warp_tile << "\n";
oss << " Pipeline: " << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline)
<< "\n";
oss << " Scheduler: " << GroupedConvAlgorithmInfo::scheduler_str(algorithm.scheduler)
<< "\n";
oss << " Arch:\n";
oss << " Target: " << arch.name << "\n";
return oss.str();
}
};
// =============================================================================
// Predefined Configs
// =============================================================================
namespace configs {
// Memory-bound config
template <typename PrecType>
struct Memory : public GroupedConvConfig
{
Memory()
{
algorithm.tile = {128, 32, 128 / (int)sizeof(PrecType)};
algorithm.warp = {4, 1, 1, 32, 32, 16};
algorithm.pipeline = PipelineVersion::MEMORY;
algorithm.double_smem_buffer = false;
}
};
// Compute V3 - Small
template <typename PrecType>
struct CompV3_Small : public GroupedConvConfig
{
CompV3_Small()
{
algorithm.tile = {16, 64, 64};
algorithm.warp = {1, 4, 1, 16, 16, 32};
algorithm.pipeline = PipelineVersion::V3;
}
};
// Compute V3 - Medium
template <typename PrecType>
struct CompV3_Medium : public GroupedConvConfig
{
CompV3_Medium()
{
algorithm.tile = {128, 128, 128 / (int)sizeof(PrecType)};
algorithm.warp = {2, 2, 1, 16, 16, 32};
algorithm.pipeline = PipelineVersion::V3;
algorithm.block_per_cu = 2;
}
};
// Compute V3 - Large
template <typename PrecType>
struct CompV3_Large : public GroupedConvConfig
{
CompV3_Large()
{
algorithm.tile = {256, 256, 128 / (int)sizeof(PrecType)};
algorithm.warp = {2, 2, 1, 32, 32, 16};
algorithm.pipeline = PipelineVersion::V3;
}
};
// Compute V4 - Double buffered
template <typename PrecType>
struct CompV4 : public GroupedConvConfig
{
CompV4()
{
algorithm.tile = {256, 256, 64 / (int)sizeof(PrecType)};
algorithm.warp = {2, 2, 1, 32, 32, 16};
algorithm.pipeline = PipelineVersion::V4;
algorithm.double_smem_buffer = true;
}
};
// Compute V5 - Wave groups
template <typename PrecType>
struct CompV5 : public GroupedConvConfig
{
CompV5()
{
algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)};
algorithm.warp = {1, 1, 2, 32, 32, 16};
algorithm.pipeline = PipelineVersion::V5;
algorithm.num_wave_groups = 2;
}
};
// WMMA config for gfx11xx
template <typename PrecType>
struct WMMA : public GroupedConvConfig
{
WMMA()
{
algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)};
algorithm.warp = {4, 2, 1, 16, 16, 16};
algorithm.pipeline = PipelineVersion::V3;
algorithm.block_per_cu = 2;
arch.name = "gfx1100";
}
};
// Merged groups config
template <typename PrecType>
struct CompV3_MergedGroups : public GroupedConvConfig
{
CompV3_MergedGroups()
{
algorithm.tile = {16, 32, 32};
algorithm.warp = {1, 2, 1, 16, 16, 32};
algorithm.vector_size = {4, 8, 8};
algorithm.pipeline = PipelineVersion::V3;
algorithm.num_groups_to_merge = 2;
}
};
} // namespace configs
// =============================================================================
// DataType Traits (compile-time type info for CK Tile types)
// =============================================================================
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
static constexpr int size_bytes = 4;
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
static constexpr int size_bytes = 8;
};
// Note: CK Tile-specific types (fp16/bf16/fp8) are defined in
// ck_tile/core/numeric/; DataTypeTraits can be specialized for them where
// those headers are available. These traits allow working with type names
// at compile time.
// =============================================================================
// ConvTypeConfig (input/weight/acc/output type combinations)
// =============================================================================
template <typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename AccDataType = float>
struct ConvTypeConfig
{
using input_type = InDataType;
using weight_type = WeiDataType;
using output_type = OutDataType;
using accumulator_type = AccDataType;
};
// Common type configurations (instantiations of ConvTypeConfig):
// FP16 -> FP32 accumulator -> FP16 output (most common)
// BF16 -> FP32 accumulator -> BF16 output
// FP8 -> FP32 accumulator -> FP8 output
// INT8 -> INT32 accumulator -> INT8 output
} // namespace dispatcher
} // namespace ck_tile


@@ -0,0 +1,537 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file grouped_conv_kernel_decl.hpp
* @brief Declarative grouped convolution kernel specification
*
* USAGE:
* ======
*
* // Named kernel sets for grouped convolution
* DECL_GROUPED_CONV_KERNEL_SET(gconv_fwd,
* .add("fp16", "nhwc", "forward", 128, 128, 32)
* .add("fp16", "nhwc", "forward", 256, 256, 64)
* );
*
* // Access at runtime
* auto& set = GroupedConvKernelSetRegistry::instance().get("gconv_fwd");
*/
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include <unordered_map>
#include <iostream>
#include <sstream>
namespace ck_tile {
namespace dispatcher {
namespace grouped_conv_decl {
// =============================================================================
// Wildcard constants
// =============================================================================
constexpr const char* ANY = "*";
constexpr int ANY_INT = -1;
// =============================================================================
// GroupedConvSignature - WHAT operation
// =============================================================================
class GroupedConvSignature
{
public:
std::string dtype_in_ = "fp16"; // Input data type
std::string dtype_wei_ = "fp16"; // Weight data type
std::string dtype_out_ = "fp16"; // Output data type
std::string dtype_acc_ = "fp32"; // Accumulator type
std::string dtype_workspace_ = "fp32"; // Workspace type (two-stage algorithms)
std::string dtype_bias_ = "fp16"; // Bias type (bias epilogue)
std::string layout_ = "nhwc"; // Data layout: nhwc, nchw
std::string conv_op_ = "forward"; // forward, bwd_data, bwd_weight
int num_dims_ = 2; // Spatial dimensions: 1, 2, or 3
int groups_ = 1; // Group grouped convolution
std::string specialization_ = "default"; // Filter specialization
GroupedConvSignature& dtype(const std::string& in,
const std::string& wei,
const std::string& out,
const std::string& acc = "fp32")
{
dtype_in_ = in;
dtype_wei_ = wei;
dtype_out_ = out;
dtype_acc_ = acc;
return *this;
}
GroupedConvSignature& dtype(const std::string& all)
{
dtype_in_ = dtype_wei_ = dtype_out_ = dtype_bias_ = all;
dtype_acc_ = dtype_workspace_ = "fp32";
return *this;
}
GroupedConvSignature& dtype_workspace(const std::string& ws)
{
dtype_workspace_ = ws;
return *this;
}
GroupedConvSignature& dtype_bias(const std::string& b)
{
dtype_bias_ = b;
return *this;
}
GroupedConvSignature& layout(const std::string& l)
{
layout_ = l;
return *this;
}
GroupedConvSignature& conv_type(const std::string& op)
{
conv_op_ = op;
return *this;
}
GroupedConvSignature& dims(int d)
{
num_dims_ = d;
return *this;
}
GroupedConvSignature& groups(int g)
{
groups_ = g;
return *this;
}
GroupedConvSignature& spec(const std::string& s)
{
specialization_ = s;
return *this;
}
std::string op_str() const
{
if(conv_op_ == "forward")
return "fwd";
if(conv_op_ == "bwd_data")
return "bwd_data";
if(conv_op_ == "bwd_weight")
return "bwd_weight";
return conv_op_;
}
};
// =============================================================================
// GroupedConvAlgorithm - HOW it's implemented
// =============================================================================
class GroupedConvAlgorithm
{
public:
// Tile shape (M, N, K per tile: M = output spatial x batch, N = output
// channels K, K = input channels C)
int tile_m_ = 1; // Tile M (output spatial * batch)
int tile_n_ = 128; // Tile N (output channels K)
int tile_k_ = 128; // Tile K (input channels C)
// Output spatial tile
int tile_ho_ = 1;
int tile_wo_ = 16;
// Wave/warp shape
int wave_m_ = ANY_INT;
int wave_n_ = ANY_INT;
int wave_k_ = 1;
int warp_m_ = ANY_INT;
int warp_n_ = ANY_INT;
int warp_k_ = 16;
// Vector sizes
int vector_a_ = 4; // Input vector size
int vector_b_ = 8; // Weight vector size
int vector_c_ = 8; // Output vector size
// Pipeline configuration
std::string pipeline_ = "compv4";
std::string scheduler_ = "intrawave";
std::string epilogue_ = "cshuffle";
std::string memory_op_ = "set"; // Memory operation: set, atomic_add, atomic_max, add
// Occupancy/performance hints
int block_size_ = 256;
int block_per_cu_ = 1;
int num_wave_groups_ = 1;
int num_groups_to_merge_ = 1;
bool double_smem_buffer_ = false;
// Padding -- always enabled for convolution (MNK padding assumed)
static constexpr bool pad_m_ = true;
static constexpr bool pad_n_ = true;
static constexpr bool pad_k_ = true;
// Tile setter (M, N, K)
GroupedConvAlgorithm& tile(int m, int n, int k)
{
tile_m_ = m;
tile_n_ = n;
tile_k_ = k;
return *this;
}
GroupedConvAlgorithm& tile_output(int ho, int wo)
{
tile_ho_ = ho;
tile_wo_ = wo;
return *this;
}
GroupedConvAlgorithm& wave(int m, int n, int k = 1)
{
wave_m_ = m;
wave_n_ = n;
wave_k_ = k;
return *this;
}
GroupedConvAlgorithm& warp(int m, int n, int k = 16)
{
warp_m_ = m;
warp_n_ = n;
warp_k_ = k;
return *this;
}
GroupedConvAlgorithm& vector_sizes(int a, int b, int c)
{
vector_a_ = a;
vector_b_ = b;
vector_c_ = c;
return *this;
}
GroupedConvAlgorithm& pipeline(const std::string& p)
{
pipeline_ = p;
return *this;
}
GroupedConvAlgorithm& scheduler(const std::string& s)
{
scheduler_ = s;
return *this;
}
GroupedConvAlgorithm& epilogue(const std::string& e)
{
epilogue_ = e;
return *this;
}
GroupedConvAlgorithm& memory_op(const std::string& m)
{
memory_op_ = m;
return *this;
}
// Occupancy setters
GroupedConvAlgorithm& block_per_cu(int b)
{
block_per_cu_ = b;
return *this;
}
GroupedConvAlgorithm& num_wave_groups(int n)
{
num_wave_groups_ = n;
return *this;
}
GroupedConvAlgorithm& num_groups_to_merge(int n)
{
num_groups_to_merge_ = n;
return *this;
}
GroupedConvAlgorithm& double_smem_buffer(bool d)
{
double_smem_buffer_ = d;
return *this;
}
bool needs_expansion() const
{
    return needs_wave_expansion() || needs_warp_expansion() || pipeline_ == "*" ||
           scheduler_ == "*";
}
/// Check if specific parameter needs expansion
bool needs_wave_expansion() const { return wave_m_ == ANY_INT || wave_n_ == ANY_INT; }
bool needs_warp_expansion() const { return warp_m_ == ANY_INT || warp_n_ == ANY_INT; }
bool needs_pipeline_expansion() const { return pipeline_ == "*"; }
bool needs_scheduler_expansion() const { return scheduler_ == "*"; }
/// Auto-fill with defaults (for single kernel generation)
void auto_fill()
{
if(wave_m_ == ANY_INT)
wave_m_ = 2;
if(wave_n_ == ANY_INT)
wave_n_ = 2;
if(warp_m_ == ANY_INT)
warp_m_ = 32;
if(warp_n_ == ANY_INT)
warp_n_ = 32;
if(pipeline_ == "*")
pipeline_ = "compv4";
if(scheduler_ == "*")
scheduler_ = "intrawave";
}
/// Get all valid wave configurations for arch
static std::vector<std::tuple<int, int, int>> valid_wave_configs(const std::string& arch)
{
// Match arch_specs_generated.py WARP_SUPPORTED_COMBINATIONS
if(arch == "gfx942" || arch == "gfx90a" || arch == "gfx950")
{
return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}};
}
return {{2, 2, 1}}; // Default
}
/// Get all valid warp tile configurations
static std::vector<std::tuple<int, int, int>> valid_warp_configs(const std::string& arch,
const std::string& dtype)
{
// Match arch_specs_generated.py WARP_TILE_SUPPORTED_COMBINATIONS
if(arch == "gfx942" && (dtype == "fp16" || dtype == "bf16"))
{
return {{16, 16, 16}, {32, 32, 16}};
}
return {{32, 32, 16}}; // Default
}
/// Get all valid pipeline/scheduler combinations for forward conv.
/// Backward operations (bwd_data/bwd_weight) only support compv3 and mem
/// due to transpose_tile2d and get_length constraints in CK Tile.
static std::vector<std::pair<std::string, std::string>> valid_trait_configs()
{
return {
{"compv3", "intrawave"},
{"compv4", "intrawave"},
{"compv5", "intrawave"},
{"mem", "intrawave"},
{"mem", "interwave"},
};
}
};
// =============================================================================
// GroupedConvKernelDecl
// =============================================================================
struct GroupedConvKernelDecl
{
GroupedConvSignature signature;
GroupedConvAlgorithm algorithm;
std::string arch = "gfx942";
GroupedConvKernelDecl() = default;
GroupedConvKernelDecl(const GroupedConvSignature& sig,
const GroupedConvAlgorithm& algo,
const std::string& a = "gfx942")
: signature(sig), algorithm(algo), arch(a)
{
}
std::string name() const
{
std::ostringstream oss;
// Generate full kernel name similar to GEMM:
// grouped_conv_<op>_<dtype>_<layout>_<ndim>d_<pipeline>_<epilogue>_<scheduler>_<tile>_<wave>_<warp>
oss << "grouped_conv_" << signature.op_str() << "_" << signature.dtype_in_ << "_"
<< signature.layout_ << "_" << signature.num_dims_ << "d" << "_" << algorithm.pipeline_
<< "_" << algorithm.epilogue_ << "_" << algorithm.scheduler_ << "_" << algorithm.tile_m_
<< "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_ << "_" << algorithm.wave_m_
<< "x" << algorithm.wave_n_ << "x" << algorithm.wave_k_ << "_" << algorithm.warp_m_
<< "x" << algorithm.warp_n_ << "x" << algorithm.warp_k_;
return oss.str();
}
bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; }
};
// =============================================================================
// GroupedConvKernelSet
// =============================================================================
class GroupedConvKernelSet
{
public:
GroupedConvKernelSet() = default;
GroupedConvKernelSet& add(const GroupedConvSignature& sig,
const GroupedConvAlgorithm& algo,
const std::string& arch = "gfx942")
{
decls_.emplace_back(sig, algo, arch);
return *this;
}
// Simple add: dtype, layout, conv_type, tile_k, tile_c
GroupedConvKernelSet& add(const std::string& dtype,
const std::string& layout,
const std::string& conv_type,
int tile_k,
int tile_c,
const std::string& arch = "gfx942")
{
GroupedConvSignature sig;
sig.dtype(dtype).layout(layout).conv_type(conv_type);
GroupedConvAlgorithm algo;
algo.tile(1, tile_k, tile_c);
decls_.emplace_back(sig, algo, arch);
return *this;
}
GroupedConvKernelSet& merge(const GroupedConvKernelSet& other)
{
decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end());
return *this;
}
const std::vector<GroupedConvKernelDecl>& declarations() const { return decls_; }
size_t size() const { return decls_.size(); }
void print(std::ostream& os = std::cout) const
{
os << "GroupedConvKernelSet (" << size() << " declarations):\n";
for(const auto& d : decls_)
{
os << " - " << d.name();
if(d.algorithm.needs_expansion())
os << " [expands]";
os << "\n";
}
}
GroupedConvKernelSet& tag(const std::string& t)
{
tag_ = t;
return *this;
}
std::string tag() const { return tag_; }
private:
std::vector<GroupedConvKernelDecl> decls_;
std::string tag_;
};
// =============================================================================
// GroupedConvKernelSetRegistry
// =============================================================================
class GroupedConvKernelSetRegistry
{
public:
static GroupedConvKernelSetRegistry& instance()
{
static GroupedConvKernelSetRegistry reg;
return reg;
}
void add(const std::string& name, const GroupedConvKernelSet& set)
{
sets_[name] = set;
if(std::find(order_.begin(), order_.end(), name) == order_.end())
{
order_.push_back(name);
}
}
// Alias for add() for consistency with GEMM API
void register_set(const std::string& name, const GroupedConvKernelSet& set) { add(name, set); }
const GroupedConvKernelSet& get(const std::string& name) const
{
static GroupedConvKernelSet empty;
auto it = sets_.find(name);
return it != sets_.end() ? it->second : empty;
}
bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); }
std::vector<std::string> names() const { return order_; }
size_t size() const { return sets_.size(); }
void clear()
{
sets_.clear();
order_.clear();
}
void print() const
{
std::cout << "Grouped Conv Kernel Sets (" << size() << "):\n";
for(const auto& name : order_)
{
const auto& set = sets_.at(name);
std::cout << " " << name << ": " << set.size() << " declarations\n";
}
}
private:
GroupedConvKernelSetRegistry() = default;
std::unordered_map<std::string, GroupedConvKernelSet> sets_;
std::vector<std::string> order_;
};
// =============================================================================
// Static Registrar
// =============================================================================
struct GroupedConvKernelSetRegistrar
{
GroupedConvKernelSetRegistrar(const std::string& name, const GroupedConvKernelSet& set)
{
GroupedConvKernelSetRegistry::instance().add(name, set);
}
};
} // namespace grouped_conv_decl
// Convenience aliases
using GroupedConvSignature = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgorithm = grouped_conv_decl::GroupedConvAlgorithm;
using GroupedConvKernelDecl = grouped_conv_decl::GroupedConvKernelDecl;
using GroupedConvKernelSet = grouped_conv_decl::GroupedConvKernelSet;
using GroupedConvKernelSetRegistry = grouped_conv_decl::GroupedConvKernelSetRegistry;
} // namespace dispatcher
} // namespace ck_tile
// =============================================================================
// Declaration Macros
// =============================================================================
#define CK_GROUPED_CONV_DECL_CAT_(a, b) CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b)
#define CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) a##b
// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension
#define DECL_GROUPED_CONV_KERNEL_SET(name, ...) \
__extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \
CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \
#name, \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() __VA_ARGS__.tag(#name))
#define DECL_GROUPED_CONV_KERNEL_ALL(dtype, layout) \
__extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \
CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \
#dtype "_" #layout "_all", \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet().add( \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvSignature().dtype(#dtype).layout( \
#layout), \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvAlgorithm(), \
"*"))
#define GROUPED_CONV_KERNEL_SET(name) \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet name
#define BEGIN_GROUPED_CONV_KERNEL_SET() \
::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet()


@@ -0,0 +1,255 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file grouped_conv_problem.hpp
* @brief Grouped Convolution problem definition
*/
#pragma once
#include <cstdint>
#include <array>
#include <stdexcept>
#include <string>
namespace ck_tile {
namespace dispatcher {
/**
* @brief Grouped Convolution operation type
*/
enum class GroupedConvOp
{
Forward, // Y = Conv(X, W)
BackwardData, // dX = ConvBwdData(dY, W)
BackwardWeight // dW = ConvBwdWeight(X, dY)
};
/**
* @brief Grouped Convolution problem specification
*/
struct GroupedConvProblem
{
// Batch and channels
std::int64_t N; // Batch size
std::int64_t C; // Input channels
std::int64_t K; // Output channels (filters)
std::int64_t G; // Number of groups (1 for standard conv)
// Spatial dimensions (supports 1D, 2D, 3D)
std::array<std::int64_t, 3> input_spatial; // {D, H, W} or {1, H, W} for 2D
std::array<std::int64_t, 3> filter_spatial; // {Z, Y, X} or {1, Y, X} for 2D
std::array<std::int64_t, 3> output_spatial; // {Do, Ho, Wo} or {1, Ho, Wo} for 2D
// Convolution parameters
std::array<std::int64_t, 3> stride; // Stride in each dimension
std::array<std::int64_t, 3> padding; // Padding in each dimension
std::array<std::int64_t, 3> dilation; // Dilation in each dimension
// Operation type
GroupedConvOp op = GroupedConvOp::Forward;
// Split-K for backward weight (k_batch parameter in CK Tile).
// Values > 1 split the reduction dimension across multiple thread blocks
// and use atomic accumulation.
int split_k = 1;
// Default constructor for 2D convolution
GroupedConvProblem()
: N(1),
C(64),
K(64),
G(1),
input_spatial{1, 28, 28},
filter_spatial{1, 3, 3},
output_spatial{1, 26, 26},
stride{1, 1, 1},
padding{0, 0, 0},
dilation{1, 1, 1},
op(GroupedConvOp::Forward)
{
}
// Constructor for 2D convolution
GroupedConvProblem(std::int64_t n,
std::int64_t c,
std::int64_t k,
std::int64_t hi,
std::int64_t wi,
std::int64_t y,
std::int64_t x,
std::int64_t stride_h = 1,
std::int64_t stride_w = 1,
std::int64_t pad_h = 0,
std::int64_t pad_w = 0,
std::int64_t dilation_h = 1,
std::int64_t dilation_w = 1)
: N(n),
C(c),
K(k),
G(1),
input_spatial{1, hi, wi},
filter_spatial{1, y, x},
stride{1, stride_h, stride_w},
padding{0, pad_h, pad_w},
dilation{1, dilation_h, dilation_w},
op(GroupedConvOp::Forward)
{
compute_output_size();
}
/// Check if problem dimensions are valid
bool is_valid() const
{
return N > 0 && C > 0 && K > 0 && G > 0 && (C % G == 0) && (K % G == 0);
}
/// Compute output spatial dimensions
void compute_output_size()
{
for(int i = 0; i < 3; ++i)
{
std::int64_t effective_filter = (filter_spatial[i] - 1) * dilation[i] + 1;
output_spatial[i] =
(input_spatial[i] + 2 * padding[i] - effective_filter) / stride[i] + 1;
}
}
/// Get 2D height/width accessors
std::int64_t Hi() const { return input_spatial[1]; }
std::int64_t Wi() const { return input_spatial[2]; }
std::int64_t Ho() const { return output_spatial[1]; }
std::int64_t Wo() const { return output_spatial[2]; }
std::int64_t Y() const { return filter_spatial[1]; } // Filter height
std::int64_t X() const { return filter_spatial[2]; } // Filter width
/// Get total FLOPs for this convolution
double get_flops() const
{
// Forward: 2 * N * K * Ho * Wo * C * Y * X / G
double spatial_out = 1.0;
double filter_size = 1.0;
for(int i = 0; i < 3; ++i)
{
spatial_out *= output_spatial[i];
filter_size *= filter_spatial[i];
}
return 2.0 * N * K * spatial_out * (C / G) * filter_size;
}
/// Check if this is a depthwise convolution
bool is_depthwise() const { return G == C && G == K; }
/// Check if this is a pointwise (1x1) convolution
bool is_pointwise() const
{
return filter_spatial[0] == 1 && filter_spatial[1] == 1 && filter_spatial[2] == 1;
}
/// String representation
std::string to_string() const
{
std::string s = "GroupedConvProblem(N=" + std::to_string(N);
s += ", C=" + std::to_string(C) + ", K=" + std::to_string(K);
s += ", G=" + std::to_string(G);
s += ", Hi=" + std::to_string(Hi()) + ", Wi=" + std::to_string(Wi());
s += ", Y=" + std::to_string(Y()) + ", X=" + std::to_string(X());
s += ", Ho=" + std::to_string(Ho()) + ", Wo=" + std::to_string(Wo());
s += ")";
return s;
}
};
// =============================================================================
// GroupedConvProblemBuilder
// =============================================================================
/// Builder pattern for Grouped Convolution problem configuration
class GroupedConvProblemBuilder
{
public:
GroupedConvProblemBuilder() = default;
GroupedConvProblemBuilder& batch(std::int64_t n)
{
problem_.N = n;
return *this;
}
GroupedConvProblemBuilder& channels(std::int64_t c, std::int64_t k)
{
problem_.C = c;
problem_.K = k;
return *this;
}
GroupedConvProblemBuilder& groups(std::int64_t g)
{
problem_.G = g;
return *this;
}
GroupedConvProblemBuilder& input_size(std::int64_t h, std::int64_t w)
{
problem_.input_spatial[0] = 1;
problem_.input_spatial[1] = h;
problem_.input_spatial[2] = w;
return *this;
}
GroupedConvProblemBuilder& filter_size(std::int64_t y, std::int64_t x)
{
problem_.filter_spatial[0] = 1;
problem_.filter_spatial[1] = y;
problem_.filter_spatial[2] = x;
return *this;
}
GroupedConvProblemBuilder& stride(std::int64_t sh, std::int64_t sw)
{
problem_.stride[0] = 1;
problem_.stride[1] = sh;
problem_.stride[2] = sw;
return *this;
}
GroupedConvProblemBuilder& padding(std::int64_t ph, std::int64_t pw)
{
problem_.padding[0] = 0;
problem_.padding[1] = ph;
problem_.padding[2] = pw;
return *this;
}
GroupedConvProblemBuilder& dilation(std::int64_t dh, std::int64_t dw)
{
problem_.dilation[0] = 1;
problem_.dilation[1] = dh;
problem_.dilation[2] = dw;
return *this;
}
GroupedConvProblemBuilder& operation(GroupedConvOp op)
{
problem_.op = op;
return *this;
}
[[nodiscard]] GroupedConvProblem build() const
{
GroupedConvProblem p = problem_;
p.compute_output_size();
if(!p.is_valid())
{
throw std::invalid_argument("Invalid grouped convolution problem dimensions");
}
return p;
}
private:
GroupedConvProblem problem_;
};
} // namespace dispatcher
} // namespace ck_tile


@@ -0,0 +1,614 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file grouped_conv_registry.hpp
* @brief Grouped Convolution kernel registry and dispatcher
*/
#pragma once
#include <string>
#include <vector>
#include <unordered_map>
#include <functional>
#include <memory>
#include <stdexcept>
#include <mutex>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <map>
#include "ck_tile/dispatcher/base_registry.hpp"
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp"
namespace ck_tile {
namespace dispatcher {
// =============================================================================
// Thread-local buffer context for GroupedConvDispatcher::run()
// The generated conv backend RunFn reads these to get buffer pointers.
// =============================================================================
struct ConvDispatchBuffers
{
const void* input_ptr = nullptr;
const void* weight_ptr = nullptr;
void* output_ptr = nullptr;
int warmup = 3;
int repeat = 10;
bool benchmarking = true;
int split_k = 1;
};
inline thread_local ConvDispatchBuffers g_conv_dispatch_buffers;
// =============================================================================
// GroupedConvKernelKey - Unique identifier for a grouped convolution kernel
// =============================================================================
struct GroupedConvKernelKey
{
// Signature fields
std::string dtype_in;
std::string dtype_wei;
std::string dtype_out;
std::string layout; // e.g., "nhwgc"
int ndim_spatial = 2; // 1, 2, or 3
GroupedConvOp op = GroupedConvOp::Forward;
// Tile configuration
int tile_m = 1;
int tile_n = 128;
int tile_k = 128;
// Wave/warp configuration
int wave_m = 2;
int wave_n = 2;
int wave_k = 1;
int warp_m = 32;
int warp_n = 32;
int warp_k = 16;
// Pipeline
std::string pipeline = "compv3";
std::string scheduler = "intrawave";
std::string epilogue = "cshuffle";
// ConvConfigBase parity fields
int vector_size_a = 4;
int vector_size_b = 8;
int vector_size_c = 8;
int block_per_cu = 1;
int num_wave_groups = 1;
int num_groups_to_merge = 1;
// GPU architecture (for filter_by_arch)
std::string arch = "gfx942";
bool operator==(const GroupedConvKernelKey& other) const
{
return dtype_in == other.dtype_in && dtype_wei == other.dtype_wei &&
dtype_out == other.dtype_out && layout == other.layout &&
ndim_spatial == other.ndim_spatial && op == other.op && tile_m == other.tile_m &&
tile_n == other.tile_n && tile_k == other.tile_k && wave_m == other.wave_m &&
wave_n == other.wave_n && wave_k == other.wave_k && warp_m == other.warp_m &&
warp_n == other.warp_n && warp_k == other.warp_k && pipeline == other.pipeline &&
scheduler == other.scheduler && epilogue == other.epilogue &&
vector_size_a == other.vector_size_a && vector_size_b == other.vector_size_b &&
vector_size_c == other.vector_size_c && block_per_cu == other.block_per_cu &&
num_wave_groups == other.num_wave_groups &&
num_groups_to_merge == other.num_groups_to_merge && arch == other.arch;
}
std::string to_string() const
{
std::string op_str;
switch(op)
{
case GroupedConvOp::Forward: op_str = "fwd"; break;
case GroupedConvOp::BackwardData: op_str = "bwd_data"; break;
case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break;
}
return "grouped_conv_" + op_str + "_" + dtype_in + "_" + std::to_string(ndim_spatial) +
"d_" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "x" +
std::to_string(tile_k) + "_" + std::to_string(wave_m) + "x" +
std::to_string(wave_n) + "x" + std::to_string(wave_k) + "_" +
std::to_string(warp_m) + "x" + std::to_string(warp_n) + "x" +
std::to_string(warp_k) + "_" + pipeline;
}
};
struct GroupedConvKernelKeyHash
{
std::size_t operator()(const GroupedConvKernelKey& key) const
{
std::size_t h = std::hash<std::string>{}(key.dtype_in);
h ^= std::hash<std::string>{}(key.layout) << 1;
h ^= std::hash<int>{}(key.ndim_spatial) << 2;
h ^= std::hash<int>{}(static_cast<int>(key.op)) << 3;
h ^= std::hash<int>{}(key.tile_m) << 4;
h ^= std::hash<int>{}(key.tile_n) << 5;
h ^= std::hash<int>{}(key.tile_k) << 6;
h ^= std::hash<int>{}(key.wave_m) << 7;
h ^= std::hash<int>{}(key.wave_n) << 8;
h ^= std::hash<int>{}(key.warp_m) << 9;
h ^= std::hash<int>{}(key.warp_n) << 10;
h ^= std::hash<std::string>{}(key.pipeline) << 11;
h ^= std::hash<std::string>{}(key.arch) << 12;
return h;
}
};
// =============================================================================
// GroupedConvKernelInstance - Runtime representation of a kernel
// =============================================================================
// Forward declaration for shared_ptr type alias
class GroupedConvKernelInstance;
using GroupedConvKernelInstancePtr = std::shared_ptr<GroupedConvKernelInstance>;
class GroupedConvKernelInstance
{
public:
using RunFn = std::function<float(const GroupedConvProblem&, void*)>;
GroupedConvKernelInstance(const GroupedConvKernelKey& key,
const std::string& name,
RunFn run_fn)
: key_(key), name_(name), run_fn_(std::move(run_fn))
{
}
const GroupedConvKernelKey& key() const { return key_; }
const std::string& name() const { return name_; }
float run(const GroupedConvProblem& problem, void* stream = nullptr) const
{
return run_fn_(problem, stream);
}
bool matches(const GroupedConvProblem& problem) const
{
// Check if this kernel can handle the problem
return problem.op == key_.op;
}
private:
GroupedConvKernelKey key_;
std::string name_;
RunFn run_fn_;
};
// =============================================================================
// GroupedConvRegistry - Stores and manages grouped convolution kernels
// =============================================================================
class GroupedConvRegistry : public BaseRegistry<GroupedConvRegistry,
GroupedConvKernelKey,
GroupedConvKernelInstance,
GroupedConvKernelKeyHash>
{
using Base = BaseRegistry<GroupedConvRegistry,
GroupedConvKernelKey,
GroupedConvKernelInstance,
GroupedConvKernelKeyHash>;
public:
GroupedConvRegistry() = default;
/// Singleton instance for global kernel registration
static GroupedConvRegistry& instance()
{
static GroupedConvRegistry registry;
return registry;
}
/// Register kernels from a GroupedConvKernelSet (atomic batch registration)
bool register_set(const GroupedConvKernelSet& kernel_set, Priority priority = Priority::Normal)
{
// Build all instances first, then register under a single lock hold
// so readers never see a half-registered set.
std::vector<std::pair<GroupedConvKernelKey, std::shared_ptr<GroupedConvKernelInstance>>>
batch;
batch.reserve(kernel_set.declarations().size());
for(const auto& decl : kernel_set.declarations())
{
GroupedConvKernelKey key;
key.dtype_in = decl.signature.dtype_in_;
key.dtype_wei = decl.signature.dtype_wei_;
key.dtype_out = decl.signature.dtype_out_;
key.layout = decl.signature.layout_;
key.ndim_spatial = decl.signature.num_dims_;
key.op = (decl.signature.conv_op_ == "forward") ? GroupedConvOp::Forward
: (decl.signature.conv_op_ == "bwd_data") ? GroupedConvOp::BackwardData
: GroupedConvOp::BackwardWeight;
key.tile_m = decl.algorithm.tile_m_;
key.tile_n = decl.algorithm.tile_n_;
key.tile_k = decl.algorithm.tile_k_;
key.wave_m = decl.algorithm.wave_m_;
key.wave_n = decl.algorithm.wave_n_;
key.wave_k = decl.algorithm.wave_k_;
key.warp_m = decl.algorithm.warp_m_;
key.warp_n = decl.algorithm.warp_n_;
key.warp_k = decl.algorithm.warp_k_;
key.pipeline = decl.algorithm.pipeline_;
key.scheduler = decl.algorithm.scheduler_;
key.epilogue = decl.algorithm.epilogue_;
key.vector_size_a = decl.algorithm.vector_a_;
key.vector_size_b = decl.algorithm.vector_b_;
key.vector_size_c = decl.algorithm.vector_c_;
key.block_per_cu = decl.algorithm.block_per_cu_;
key.num_wave_groups = decl.algorithm.num_wave_groups_;
key.num_groups_to_merge = decl.algorithm.num_groups_to_merge_;
key.arch = decl.arch;
        // Declaration-only registration: the RunFn here is a stub that returns
        // 0.0f; a real launcher is expected to be bound separately (e.g. by the
        // generated conv backend).
        batch.emplace_back(key,
                           std::make_shared<GroupedConvKernelInstance>(
                               key, decl.name(), [](const GroupedConvProblem&, void*) -> float {
                                   return 0.0f;
                               }));
}
std::lock_guard<std::mutex> lock(mutex());
bool any_registered = false;
for(auto& [key, instance] : batch)
{
auto it = entries().find(key);
if(it == entries().end() || it->second.priority <= priority)
{
entries_mut()[key] = typename Base::Entry{std::move(instance), priority};
any_registered = true;
}
}
return any_registered;
}
/// Find the best kernel for a problem
const GroupedConvKernelInstance* find(const GroupedConvProblem& problem) const
{
std::lock_guard<std::mutex> lock(mutex());
const GroupedConvKernelInstance* best = nullptr;
Priority best_priority = Priority::Low;
for(const auto& [key, entry] : entries())
{
if(entry.instance->matches(problem))
{
if(!best || entry.priority > best_priority)
{
best = entry.instance.get();
best_priority = entry.priority;
}
}
}
return best;
}
/// Get all registered kernels
std::vector<const GroupedConvKernelInstance*> all_kernels() const
{
std::lock_guard<std::mutex> lock(mutex());
std::vector<const GroupedConvKernelInstance*> result;
for(const auto& [key, entry] : entries())
{
result.push_back(entry.instance.get());
}
return result;
}
/// Export registry to JSON string
std::string export_json(bool include_statistics = false) const
{
// Note: get_name() acquires the mutex internally, so we must NOT hold
// the registry mutex here (std::mutex is not recursive).
std::string reg_name = get_name();
std::lock_guard<std::mutex> lock(mutex());
std::ostringstream json;
json << "{\n";
json << " \"metadata\": {\n";
json << " \"registry_name\": \"" << json_escape(reg_name) << "\",\n";
json << " \"total_kernels\": " << entries().size() << "\n";
json << " }";
if(include_statistics && !entries().empty())
{
std::map<std::string, int> by_datatype;
std::map<std::string, int> by_pipeline;
std::map<std::string, int> by_arch;
for(const auto& [key, entry] : entries())
{
std::string dtype_key = key.dtype_in + "_" + key.dtype_wei + "_" + key.dtype_out;
by_datatype[dtype_key]++;
by_pipeline[key.pipeline]++;
by_arch[key.arch]++;
}
json << ",\n \"statistics\": {\n";
json << " \"by_datatype\": {";
bool first = true;
for(const auto& [dtype, count] : by_datatype)
{
if(!first)
json << ",";
json << "\"" << json_escape(dtype) << "\":" << count;
first = false;
}
json << "},\n";
json << " \"by_pipeline\": {";
first = true;
for(const auto& [pipeline, count] : by_pipeline)
{
if(!first)
json << ",";
json << "\"" << json_escape(pipeline) << "\":" << count;
first = false;
}
json << "},\n";
json << " \"by_arch\": {";
first = true;
for(const auto& [arch, count] : by_arch)
{
if(!first)
json << ",";
json << "\"" << json_escape(arch) << "\":" << count;
first = false;
}
json << "}\n }";
}
json << ",\n \"kernels\": [\n";
bool first = true;
for(const auto& [key, entry] : entries())
{
if(!first)
json << ",\n";
json << " " << export_kernel_json(*entry.instance);
first = false;
}
json << "\n ]\n";
json << "}\n";
return json.str();
}
/// Export registry to JSON file
void export_json_to_file(const std::string& filename, bool include_statistics = false) const
{
std::string json_str = export_json(include_statistics);
std::ofstream file(filename);
if(!file.is_open())
{
throw std::runtime_error("Failed to open file for export: " + filename);
}
file << json_str;
}
/// Get kernels matching a predicate
std::vector<const GroupedConvKernelInstance*>
filter(std::function<bool(const GroupedConvKernelInstance&)> predicate) const
{
std::lock_guard<std::mutex> lock(mutex());
std::vector<const GroupedConvKernelInstance*> result;
for(const auto& [key, entry] : entries())
{
if(predicate(*entry.instance))
{
result.push_back(entry.instance.get());
}
}
return result;
}
/// Remove kernels not matching the arch
std::size_t filter_by_arch(const std::string& gpu_arch)
{
std::lock_guard<std::mutex> lock(mutex());
std::vector<GroupedConvKernelKey> to_remove;
for(const auto& [key, entry] : entries())
{
if(key.arch != gpu_arch)
{
to_remove.push_back(key);
}
}
for(const auto& key : to_remove)
{
entries_mut().erase(key);
}
return to_remove.size();
}
private:
static std::string json_escape(const std::string& str)
{
std::ostringstream oss;
for(char c : str)
{
switch(c)
{
case '"': oss << "\\\""; break;
case '\\': oss << "\\\\"; break;
case '\b': oss << "\\b"; break;
case '\f': oss << "\\f"; break;
case '\n': oss << "\\n"; break;
case '\r': oss << "\\r"; break;
case '\t': oss << "\\t"; break;
default:
// Cast via unsigned char: plain char may be signed, so bytes >= 0x80
// must not be misread as negative control characters here.
if(static_cast<unsigned char>(c) < 0x20)
{
oss << "\\u" << std::hex << std::setw(4) << std::setfill('0')
<< static_cast<int>(static_cast<unsigned char>(c)) << std::dec;
}
else
{
oss << c;
}
}
}
return oss.str();
}
static std::string export_kernel_json(const GroupedConvKernelInstance& kernel)
{
std::ostringstream json;
const auto& key = kernel.key();
std::string op_str;
switch(key.op)
{
case GroupedConvOp::Forward: op_str = "fwd"; break;
case GroupedConvOp::BackwardData: op_str = "bwd_data"; break;
case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break;
}
json << "{\n";
json << " \"name\": \"" << json_escape(kernel.name()) << "\",\n";
json << " \"signature\": {\n";
json << " \"dtype_in\": \"" << json_escape(key.dtype_in) << "\",\n";
json << " \"dtype_wei\": \"" << json_escape(key.dtype_wei) << "\",\n";
json << " \"dtype_out\": \"" << json_escape(key.dtype_out) << "\",\n";
json << " \"layout\": \"" << json_escape(key.layout) << "\",\n";
json << " \"ndim_spatial\": " << key.ndim_spatial << ",\n";
json << " \"op\": \"" << op_str << "\"\n";
json << " },\n";
json << " \"algorithm\": {\n";
json << " \"tile_m\": " << key.tile_m << ",\n";
json << " \"tile_n\": " << key.tile_n << ",\n";
json << " \"tile_k\": " << key.tile_k << ",\n";
json << " \"wave\": \"" << key.wave_m << "x" << key.wave_n << "x" << key.wave_k
<< "\",\n";
json << " \"warp\": \"" << key.warp_m << "x" << key.warp_n << "x" << key.warp_k
<< "\",\n";
json << " \"pipeline\": \"" << json_escape(key.pipeline) << "\",\n";
json << " \"scheduler\": \"" << json_escape(key.scheduler) << "\",\n";
json << " \"epilogue\": \"" << json_escape(key.epilogue) << "\",\n";
json << " \"vector_sizes\": [" << key.vector_size_a << "," << key.vector_size_b
<< "," << key.vector_size_c << "],\n";
json << " \"block_per_cu\": " << key.block_per_cu << ",\n";
json << " \"num_wave_groups\": " << key.num_wave_groups << ",\n";
json << " \"num_groups_to_merge\": " << key.num_groups_to_merge << "\n";
json << " },\n";
json << " \"arch\": \"" << json_escape(key.arch) << "\"\n";
json << " }";
return json.str();
}
};
// =============================================================================
// GroupedConvDispatcher - Selects and runs the best kernel for a problem
// =============================================================================
class GroupedConvDispatcher
{
public:
enum class SelectionStrategy
{
PriorityBased,
Heuristic
};
using HeuristicFunction = std::function<std::vector<std::string>(const GroupedConvProblem&)>;
explicit GroupedConvDispatcher(GroupedConvRegistry* registry)
: registry_(registry), strategy_(SelectionStrategy::PriorityBased)
{
}
void set_strategy(SelectionStrategy s) { strategy_ = s; }
void set_heuristic(HeuristicFunction fn) { heuristic_ = std::move(fn); }
/// Select the best kernel for a problem (does not run it)
const GroupedConvKernelInstance* select_kernel(const GroupedConvProblem& problem) const
{
if(strategy_ == SelectionStrategy::Heuristic)
return select_heuristic(problem);
return registry_->find(problem);
}
/// Run convolution with automatic kernel selection (legacy - no buffers)
float run(const GroupedConvProblem& problem, void* stream = nullptr)
{
const auto* kernel = select_kernel(problem);
if(!kernel)
{
throw NoKernelFound("No suitable grouped convolution kernel found for problem: " +
problem.to_string());
}
return kernel->run(problem, stream);
}
/// Run convolution with buffer pointers and automatic kernel selection.
/// Sets the thread-local buffer context before dispatching to the kernel.
float run(const void* input_ptr,
const void* weight_ptr,
void* output_ptr,
const GroupedConvProblem& problem,
void* stream = nullptr,
int warmup = 3,
int repeat = 10)
{
const auto* kernel = select_kernel(problem);
if(!kernel)
{
throw NoKernelFound("No suitable grouped convolution kernel found for problem: " +
problem.to_string());
}
g_conv_dispatch_buffers.input_ptr = input_ptr;
g_conv_dispatch_buffers.weight_ptr = weight_ptr;
g_conv_dispatch_buffers.output_ptr = output_ptr;
g_conv_dispatch_buffers.warmup = warmup;
g_conv_dispatch_buffers.repeat = repeat;
g_conv_dispatch_buffers.benchmarking = benchmarking_;
g_conv_dispatch_buffers.split_k = problem.split_k;
return kernel->run(problem, stream);
}
/// Enable or disable GPU benchmarking (timing).
/// When disabled, kernels execute once with no timing overhead.
void set_benchmarking(bool enable) { benchmarking_ = enable; }
[[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; }
/// Alias kept for backward compatibility
const GroupedConvKernelInstance* select(const GroupedConvProblem& problem) const
{
return select_kernel(problem);
}
private:
const GroupedConvKernelInstance* select_heuristic(const GroupedConvProblem& problem) const
{
if(!heuristic_)
return registry_->find(problem);
auto ranked_names = heuristic_(problem);
auto all = registry_->all_kernels();
for(const auto& name : ranked_names)
{
for(const auto* kernel : all)
{
if(kernel->name().find(name) != std::string::npos && kernel->matches(problem))
{
return kernel;
}
}
}
return registry_->find(problem);
}
GroupedConvRegistry* registry_;
SelectionStrategy strategy_;
HeuristicFunction heuristic_;
bool benchmarking_ = true;
};
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -0,0 +1,324 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* @file grouped_conv_utils.hpp
* @brief CK Tile Grouped Convolution Dispatcher Utilities
*/
#pragma once
#include "ck_tile/dispatcher/grouped_conv_config.hpp"
#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp"
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/arch_filter.hpp"
#include "ck_tile/dispatcher/utils.hpp"
#include <iostream>
#include <iomanip>
#include <memory>
#include <vector>
#include <string>
#include <sstream>
#include <functional>
#include <cmath>
#include <algorithm>
namespace ck_tile {
namespace dispatcher {
using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
namespace grouped_conv_utils {
inline GroupedConvKernelDecl create_grouped_conv2d_fwd(const std::string& dtype = "fp16",
int tile_n = 128,
int tile_k = 128,
const std::string& arch = "gfx942")
{
return GroupedConvKernelDecl(
GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("forward").dims(2),
GroupedConvAlgo()
.tile(1, tile_n, tile_k)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.vector_sizes(4, 8, 8),
arch);
}
inline GroupedConvKernelDecl create_grouped_conv3d_fwd(const std::string& dtype = "fp16",
int tile_n = 64,
int tile_k = 64,
const std::string& arch = "gfx942")
{
return GroupedConvKernelDecl(
GroupedConvSig().dtype(dtype).layout("ndhwc").conv_type("forward").dims(3),
GroupedConvAlgo()
.tile(1, tile_n, tile_k)
.wave(2, 2, 1)
.warp(16, 16, 32)
.pipeline("compv3")
.vector_sizes(4, 8, 8),
arch);
}
inline GroupedConvKernelDecl create_grouped_conv2d_bwd_data(const std::string& dtype = "fp16",
int tile_n = 128,
int tile_k = 128,
const std::string& arch = "gfx942")
{
return GroupedConvKernelDecl(
GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_data").dims(2),
GroupedConvAlgo()
.tile(1, tile_n, tile_k)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.vector_sizes(4, 8, 8),
arch);
}
inline GroupedConvKernelDecl create_grouped_conv2d_bwd_weight(const std::string& dtype = "fp16",
int tile_n = 128,
int tile_k = 128,
const std::string& arch = "gfx942")
{
return GroupedConvKernelDecl(
GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_weight").dims(2),
GroupedConvAlgo()
.tile(1, tile_n, tile_k)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv3")
.memory_op("atomic_add")
.vector_sizes(4, 8, 8),
arch);
}
inline GroupedConvProblem create_grouped_conv2d_problem(int N,
int C,
int K,
int Hi,
int Wi,
int Y,
int X,
int stride = 1,
int padding = 0,
GroupedConvOp op = GroupedConvOp::Forward)
{
GroupedConvProblem p;
p.N = N;
p.C = C;
p.K = K;
p.G = 1;
p.input_spatial = {1, Hi, Wi};
p.filter_spatial = {1, Y, X};
p.stride = {1, stride, stride};
p.padding = {0, padding, padding};
p.dilation = {1, 1, 1};
p.op = op;
p.compute_output_size();
return p;
}
inline GroupedConvProblem create_grouped_conv3d_problem(int N,
int C,
int K,
int Di,
int Hi,
int Wi,
int Z,
int Y,
int X,
int stride = 1,
int padding = 0,
GroupedConvOp op = GroupedConvOp::Forward)
{
GroupedConvProblem p;
p.N = N;
p.C = C;
p.K = K;
p.G = 1;
p.input_spatial = {Di, Hi, Wi};
p.filter_spatial = {Z, Y, X};
p.stride = {stride, stride, stride};
p.padding = {padding, padding, padding};
p.dilation = {1, 1, 1};
p.op = op;
p.compute_output_size();
return p;
}
inline GroupedConvProblem create_depthwise_grouped_conv2d_problem(
int N, int C, int Hi, int Wi, int Y, int X, int stride = 1, int padding = 0)
{
GroupedConvProblem p;
p.N = N;
p.C = C;
p.K = C;
p.G = C;
p.input_spatial = {1, Hi, Wi};
p.filter_spatial = {1, Y, X};
p.stride = {1, stride, stride};
p.padding = {0, padding, padding};
p.dilation = {1, 1, 1};
p.op = GroupedConvOp::Forward;
p.compute_output_size();
return p;
}
inline void print_pattern_docs(std::ostream& os = std::cout)
{
os << "Grouped Convolution Pattern Documentation\n";
os << "==========================================\n";
os << "Signature patterns: dtype, layout, conv_type (forward/bwd_data/bwd_weight), dims "
"(2/3)\n";
os << "Algorithm patterns: tile(M,N,K), wave(M,N,K), warp(M,N,K), pipeline, vector_sizes\n";
os << "Arch patterns: gfx942, gfx90a, gfx950, or '*' for all\n";
}
inline void print_grouped_conv_kernel_decl(const GroupedConvKernelDecl& decl,
std::ostream& os = std::cout)
{
os << "GroupedConvKernelDecl: " << decl.name() << "\n";
os << " Signature: dtype=" << decl.signature.dtype_in_ << ", layout=" << decl.signature.layout_
<< ", conv_type=" << decl.signature.conv_op_ << ", dims=" << decl.signature.num_dims_
<< "\n";
os << " Algorithm: tile=" << decl.algorithm.tile_m_ << "x" << decl.algorithm.tile_n_ << "x"
<< decl.algorithm.tile_k_ << ", wave=" << decl.algorithm.wave_m_ << "x"
<< decl.algorithm.wave_n_ << "x" << decl.algorithm.wave_k_
<< ", warp=" << decl.algorithm.warp_m_ << "x" << decl.algorithm.warp_n_ << "x"
<< decl.algorithm.warp_k_ << ", pipeline=" << decl.algorithm.pipeline_ << "\n";
os << " Arch: " << decl.arch << "\n";
}
inline void print_grouped_conv_problem(const GroupedConvProblem& p, std::ostream& os = std::cout)
{
os << p.to_string() << "\n";
os << " FLOPs: " << std::scientific << p.get_flops() << "\n";
}
inline GroupedConvKernelSet build_grouped_conv2d_fwd_set(const std::string& dtype = "fp16",
const std::string& arch = "gfx942")
{
GroupedConvKernelSet set;
auto decl1 = create_grouped_conv2d_fwd(dtype, 128, 128, arch);
set.add(decl1.signature, decl1.algorithm, decl1.arch);
auto decl2 = create_grouped_conv2d_fwd(dtype, 256, 256, arch);
set.add(decl2.signature, decl2.algorithm, decl2.arch);
return set;
}
inline GroupedConvKernelSet build_grouped_conv2d_full_set(const std::string& dtype = "fp16",
const std::string& arch = "gfx942")
{
GroupedConvKernelSet set;
set.merge(build_grouped_conv2d_fwd_set(dtype, arch));
auto bwd_data = create_grouped_conv2d_bwd_data(dtype, 128, 128, arch);
set.add(bwd_data.signature, bwd_data.algorithm, bwd_data.arch);
auto bwd_weight = create_grouped_conv2d_bwd_weight(dtype, 128, 128, arch);
set.add(bwd_weight.signature, bwd_weight.algorithm, bwd_weight.arch);
return set;
}
struct ValidationResult
{
bool passed = false;
float max_abs_diff = 0.0f;
float max_rel_diff = 0.0f;
float rtol = 1e-3f;
float atol = 1e-3f;
void print(std::ostream& os = std::cout) const
{
os << "ValidationResult: " << (passed ? "PASSED" : "FAILED") << "\n";
os << " max_abs_diff: " << max_abs_diff << ", max_rel_diff: " << max_rel_diff << "\n";
os << " rtol: " << rtol << ", atol: " << atol << "\n";
}
};
template <typename T>
inline ValidationResult validate_buffers(
const T* result, const T* reference, size_t count, float rtol = 1e-3f, float atol = 1e-3f)
{
ValidationResult vr;
vr.rtol = rtol;
vr.atol = atol;
vr.passed = true;
for(size_t i = 0; i < count; ++i)
{
float r = static_cast<float>(result[i]);
float ref = static_cast<float>(reference[i]);
float abs_diff = std::abs(r - ref);
float rel_diff = (std::abs(ref) > 1e-10f) ? (abs_diff / std::abs(ref)) : 0.0f;
vr.max_abs_diff = std::max(vr.max_abs_diff, abs_diff);
vr.max_rel_diff = std::max(vr.max_rel_diff, rel_diff);
float threshold = atol + rtol * std::abs(ref);
if(abs_diff > threshold)
{
vr.passed = false;
}
}
return vr;
}
struct BenchmarkResult
{
std::string kernel_name;
float time_ms = 0.0f;
float tflops = 0.0f;
int warmup_runs = 0;
int benchmark_runs = 0;
void print(std::ostream& os = std::cout) const
{
os << "BenchmarkResult: " << kernel_name << "\n";
os << " time_ms: " << time_ms << ", tflops: " << tflops << "\n";
os << " warmup_runs: " << warmup_runs << ", benchmark_runs: " << benchmark_runs << "\n";
}
};
inline float calc_tflops(double flops, float time_ms)
{
return static_cast<float>(flops / (time_ms * 1e9));
}
inline double calculate_conv_tflops(const GroupedConvProblem& problem, double time_ms)
{
return problem.get_flops() / (time_ms * 1e9);
}
} // namespace grouped_conv_utils
namespace examples {
inline int basic_grouped_conv_example_main(const std::string& example_name)
{
std::cout << "=== " << example_name << " ===\n";
// Create a grouped convolution problem
auto problem = grouped_conv_utils::create_grouped_conv2d_problem(
32, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward);
grouped_conv_utils::print_grouped_conv_problem(problem);
// Create and print a kernel declaration
auto decl = grouped_conv_utils::create_grouped_conv2d_fwd("fp16", 128, 128, "gfx942");
grouped_conv_utils::print_grouped_conv_kernel_decl(decl);
// Build and print kernel set
auto kernel_set = grouped_conv_utils::build_grouped_conv2d_fwd_set("fp16", "gfx942");
kernel_set.print();
return 0;
}
} // namespace examples
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -98,7 +98,7 @@ struct Problem
/**
* Create Problem by inferring MNK from tensor shapes.
*
* For GEMM: C[M,N] = A[M,K] × B[K,N]
* For GEMM: C[M,N] = A[M,K] x B[K,N]
*
* @param a_shape Shape of matrix A (M x K, or K x M if transposed)
* @param b_shape Shape of matrix B (K x N, or N x K if transposed)
@@ -113,7 +113,7 @@ struct Problem
[[nodiscard]] static Problem
from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape)
{
// For C = A × B:
// For C = A x B:
// A: [M, K] (or [K, M] if transposed)
// B: [K, N] (or [N, K] if transposed)
// C: [M, N]
@@ -164,7 +164,7 @@ struct Problem
* @throws std::invalid_argument if dimensions are inconsistent
*
* Example:
* // A[512,256] × B[256,1024] = C[512,1024]
* // A[512,256] x B[256,1024] = C[512,1024]
* auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024);
*/
[[nodiscard]] static Problem from_dimensions(std::int64_t a_rows,
@@ -188,7 +188,7 @@ struct Problem
* @throws std::invalid_argument if K dimensions don't match
*
* Example:
* // A[512,256] × B[256,1024] = C[512,1024]
* // A[512,256] x B[256,1024] = C[512,1024]
* auto problem = Problem::from_ab(512, 256, 256, 1024);
*/
[[nodiscard]] static Problem

View File

@@ -7,38 +7,20 @@
* Central registry for all available kernel instances with priority-based
* ordering and efficient lookup.
*
* Features:
* - Thread-safe registration and lookup
* - Priority-based ordering (High, Normal, Low)
* - Lookup by name or KernelKey
* - Filter by problem compatibility
* - Supports both singleton and multiple instance patterns
*
* Usage (Singleton - backward compatible):
* auto& registry = Registry::instance();
* registry.register_kernel(kernel, Priority::High);
* auto kernel = registry.lookup("kernel_name");
*
* Usage (Multiple registries):
* Registry fp16_registry;
* Registry bf16_registry;
* fp16_registry.register_kernel(fp16_kernel, Priority::High);
* bf16_registry.register_kernel(bf16_kernel, Priority::High);
*
* Dispatcher fp16_dispatcher(&fp16_registry);
* Dispatcher bf16_dispatcher(&bf16_registry);
* Derives from BaseRegistry for shared logic (thread safety, naming, priority,
* merge) while keeping GEMM-specific APIs (lookup by KernelKey, filter_by_arch,
* JSON export, auto-export).
*
* Status: Production ready, thread-safe
*/
#pragma once
#include "ck_tile/dispatcher/base_registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
@@ -47,20 +29,16 @@ namespace dispatcher {
/// Registry: Central mapping from kernel configurations to executable instances
/// Thread-safe kernel registration and lookup
/// Supports both singleton pattern and multiple independent instances
class Registry
/// Derives from BaseRegistry<Registry, std::string, KernelInstance> for shared functionality
class Registry : public BaseRegistry<Registry, std::string, KernelInstance>
{
using Base = BaseRegistry<Registry, std::string, KernelInstance>;
public:
/// Priority levels for conflict resolution when multiple kernels have same key
enum class Priority
{
Low = 0,
Normal = 1,
High = 2
};
// Re-export Priority from the shared enum for backward compatibility
using Priority = ck_tile::dispatcher::Priority;
/// Default constructor - creates an empty registry instance
/// Use this to create independent registries for different kernel sets
Registry();
/// Destructor - triggers auto-export if enabled
@@ -72,106 +50,51 @@ class Registry
/// Move assignment
Registry& operator=(Registry&& other) noexcept;
// Prevent copying (registries contain shared_ptrs that shouldn't be duplicated)
// Prevent copying
Registry(const Registry&) = delete;
Registry& operator=(const Registry&) = delete;
/// Register a kernel instance with the registry
/// @param instance Kernel instance to register
/// @param priority Priority level for conflict resolution (default: Normal)
/// @return true if registered successfully, false if duplicate with higher priority exists
bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal);
/// Lookup a kernel by its string identifier
/// @param identifier Kernel identifier string
/// @return Kernel instance if found, nullptr otherwise
[[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const;
/// Lookup a kernel by its KernelKey
/// @param key Kernel configuration key
/// @return Kernel instance if found, nullptr otherwise
[[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const;
/// Get all registered kernels
/// @return Vector of all kernel instances
[[nodiscard]] std::vector<KernelInstancePtr> get_all() const;
/// Get all kernels matching a predicate
/// @param predicate Function to filter kernels
/// @return Vector of matching kernel instances
[[nodiscard]] std::vector<KernelInstancePtr>
filter(std::function<bool(const KernelInstance&)> predicate) const;
/// Get number of registered kernels
[[nodiscard]] std::size_t size() const;
/// Check if registry is empty
[[nodiscard]] bool empty() const;
/// Clear all registered kernels
void clear();
/// Get registry name (for logging/debugging)
[[nodiscard]] const std::string& get_name() const;
/// Set registry name (for logging/debugging)
void set_name(const std::string& name);
// size(), empty(), clear(), get_name(), set_name(), merge_from() inherited from Base
/// Export registry to JSON string
/// @param include_statistics Whether to include kernel statistics breakdown
/// @return JSON string with all kernel metadata
[[nodiscard]] std::string export_json(bool include_statistics = true) const;
/// Export registry to JSON file
/// @param filename Output filename
/// @param include_statistics Whether to include kernel statistics breakdown
/// @return true if export succeeded, false otherwise
bool export_json_to_file(const std::string& filename, bool include_statistics = true) const;
/// Enable automatic JSON export on kernel registration
/// @param filename Output filename for auto-export
/// @param include_statistics Whether to include statistics in auto-export
/// @param export_on_every_registration If true, exports after every registration (default).
/// If false, only exports on destruction.
void enable_auto_export(const std::string& filename,
bool include_statistics = true,
bool export_on_every_registration = true);
/// Disable automatic JSON export
void disable_auto_export();
/// Check if auto-export is enabled
[[nodiscard]] bool is_auto_export_enabled() const;
/// Merge kernels from another registry into this one
/// @param other Registry to merge from
/// @param priority Priority for merged kernels (default: Normal)
/// @return Number of kernels successfully merged
std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal);
/// Filter kernels in-place by architecture
/// @param gpu_arch Target GPU architecture string (e.g., "gfx942")
/// @return Number of kernels removed
std::size_t filter_by_arch(const std::string& gpu_arch);
/// Get singleton instance of the global registry (backward compatible)
/// This is the default registry used when no specific registry is provided
/// Get singleton instance
static Registry& instance();
private:
struct RegistryEntry
{
KernelInstancePtr instance;
Priority priority;
};
/// Perform auto-export if enabled
void perform_auto_export();
mutable std::mutex mutex_;
std::unordered_map<std::string, RegistryEntry> kernels_;
std::string name_;
// Auto-export configuration
bool auto_export_enabled_ = false;
std::string auto_export_filename_;
@@ -179,7 +102,7 @@ class Registry
bool auto_export_on_every_registration_ = true;
};
/// Shared pointer type for registries (useful for managing lifetime)
/// Shared pointer type for registries
using RegistryPtr = std::shared_ptr<Registry>;
/// Create a new registry instance (factory function)

View File

@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Grouped Convolution-only dispatcher header -- minimal include for conv operations.
#pragma once
// Core (needed by all ops)
#include "ck_tile/dispatcher/base_registry.hpp"
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
// Grouped Convolution
#include "ck_tile/dispatcher/grouped_conv_config.hpp"
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"

View File

@@ -0,0 +1,22 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// GEMM-only dispatcher header -- minimal include for GEMM operations.
#pragma once
// Core (needed by all ops)
#include "ck_tile/dispatcher/base_registry.hpp"
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
// GEMM
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/kernel_config.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/problem.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/json_export.hpp"
#include "ck_tile/dispatcher/utils.hpp"

View File

@@ -3,7 +3,7 @@
# This directory contains Python utilities for the dispatcher examples.
# The main utility file is ctypes_utils.py which is used by GEMM Python examples.
# Conv Python examples use their own conv_utils.py in the examples directory.
# Grouped conv Python examples use grouped_conv_utils.py in this directory.
# No build targets needed - these are pure Python utilities.
message(STATUS "Python utilities directory configured (no build targets)")

View File

@@ -4,6 +4,19 @@ This directory contains Python utilities used by the dispatcher examples.
## Contents
### Shared Utilities (used by both GEMM and Grouped Conv)
- `dispatcher_common.py` - Shared dispatcher infrastructure
- Path helpers (`get_dispatcher_root`, `get_build_dir`, etc.)
- `ValidationResultBase` - Structured validation feedback
- `validate_wave_config`, `validate_warp_tile_config`, `validate_trait_combo`
- `auto_correct_wave`, `auto_correct_trait` - Auto-correction helpers
- `Colors` - Cross-platform ANSI color support
- `print_phase`, `print_success`, `print_error`, `print_info` - Phased output
- `cleanup_generated_kernels` - Cleanup helper
### GEMM Utilities
- `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples
- `KernelConfig` - Kernel configuration dataclass
- `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction
@@ -11,11 +24,15 @@ This directory contains Python utilities used by the dispatcher examples.
- `GemmRunner` - GPU execution helper
- Auto-correction and validation utilities
- `conv_utils.py` - Core utilities for Conv Python examples
- `ConvSignature`, `ConvAlgorithm` - Convolution configuration
- `ConvProblem` - Problem definition
- `GpuConvRunner` - GPU execution helper
- `EnhancedConvCodegenRunner` - Kernel codegen utilities
### Grouped Convolution Utilities
- `grouped_conv_utils.py` - Utilities for grouped convolution
- `GroupedConvValidationResult` - Validation result (extends `ValidationResultBase`)
- `validate_grouped_conv_config` - Validate a grouped conv config
- `auto_correct_grouped_conv_config` - Auto-correct invalid configs
- `get_grouped_conv_default_config` - Get default config for a variant
- `GroupedConvDataType` - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8)
- `format_grouped_conv_summary` - Human-readable config summary
## Usage
@@ -36,21 +53,26 @@ from ctypes_utils import (
)
```
### Conv Examples
The Conv Python examples in `dispatcher/examples/conv/python/` import:
### Grouped Conv Usage
```python
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from conv_utils import (
ConvSignature,
ConvAlgorithm,
ConvProblem,
GpuConvRunner,
from grouped_conv_utils import (
validate_grouped_conv_config,
auto_correct_grouped_conv_config,
get_grouped_conv_default_config,
GroupedConvDataType,
)
# Get a default config
config = get_grouped_conv_default_config(variant="forward", arch="gfx942")
# Validate
result = validate_grouped_conv_config(config)
print(f"Valid: {result.is_valid}")
```
## Requirements

View File

@@ -37,6 +37,43 @@ import multiprocessing
import time
# =============================================================================
# GPU Architecture Auto-Detection
# =============================================================================
_detected_arch: Optional[str] = None
def detect_gpu_arch(fallback: str = "gfx942") -> str:
"""
Auto-detect the GPU architecture by querying rocminfo.
Caches the result after the first call. Falls back to `fallback` if
detection fails (e.g. no GPU, rocminfo not installed).
"""
global _detected_arch
if _detected_arch is not None:
return _detected_arch
try:
result = subprocess.run(
["/opt/rocm/bin/rocminfo"], capture_output=True, text=True, timeout=10
)
for line in result.stdout.splitlines():
stripped = line.strip()
if stripped.startswith("Name:") and "gfx" in stripped:
# Extract e.g. "gfx950" from "Name: gfx950"
name = stripped.split(":", 1)[1].strip()
# isalnum (not isdigit): arch names like gfx90a contain letters
if name.startswith("gfx") and name[3:].isalnum():
_detected_arch = name
return _detected_arch
except Exception:
pass
_detected_arch = fallback
return _detected_arch
# =============================================================================
# Path Configuration
# =============================================================================
@@ -159,9 +196,9 @@ class ValidationResult:
def print_result(self, indent: str = " "):
"""Print validation result."""
if self.is_valid:
print(f"{indent}✓ Configuration valid")
print(f"{indent}OK Configuration valid")
else:
print(f"{indent}⚠ Configuration has issues:")
print(f"{indent}WARNING Configuration has issues:")
for err in self.errors:
print(f"{indent} - {err}")
@@ -300,7 +337,7 @@ def auto_correct_kernel_config(
# Check each fix and describe what changed
if "scheduler" in fixes and fixes["scheduler"] != config.scheduler:
corrections.append(
f"Scheduler: {config.scheduler} → {fixes['scheduler']} "
f"Scheduler: {config.scheduler} -> {fixes['scheduler']} "
f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})"
)
@@ -309,7 +346,7 @@ def auto_correct_kernel_config(
new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]"
if old_wave != new_wave:
corrections.append(
f"Wave config: {old_wave} → {new_wave} "
f"Wave config: {old_wave} -> {new_wave} "
f"(original not supported on {config.gfx_arch})"
)
@@ -318,7 +355,7 @@ def auto_correct_kernel_config(
new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]"
if old_warp != new_warp:
corrections.append(
f"Warp tile: {old_warp} → {new_warp} "
f"Warp tile: {old_warp} -> {new_warp} "
f"(original not supported for {config.dtype_a} on {config.gfx_arch})"
)
@@ -386,13 +423,13 @@ def print_auto_correction(
indent: Indentation for output
"""
if not corrections:
print(f"{indent}✓ Configuration valid - no corrections needed")
print(f"{indent}OK Configuration valid - no corrections needed")
return
print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:")
print(f"\n{indent}WARNING AUTO-CORRECTION APPLIED:")
print(f"{indent}" + "-" * 50)
for correction in corrections:
print(f"{indent} {correction}")
print(f"{indent} - {correction}")
print(f"{indent}" + "-" * 50)
print()
@@ -976,6 +1013,226 @@ def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult:
)
def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]:
"""Module-level function to run hipcc compilation in parallel."""
import subprocess
from pathlib import Path
compile_cmd = args["compile_cmd"]
link_cmd = args["link_cmd"]
lib_path = Path(args["lib_path"])
try:
res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300)
if res_c.returncode != 0:
return False, None, f"Compile failed: {res_c.stderr[:200]}"
res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300)
if res_l.returncode != 0:
return False, None, f"Link failed: {res_l.stderr[:200]}"
return True, lib_path, ""
except subprocess.TimeoutExpired:
return False, None, "Timeout"
except Exception as e:
return False, None, str(e)
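The compile/link wrapper above reduces to a run-and-check pattern: invoke a subprocess with a timeout, treat a nonzero return code as failure, and truncate stderr for the report. A minimal standalone sketch of that pattern (`run_step` is an illustrative name, not part of the dispatcher):

```python
import subprocess
import sys

def run_step(cmd, timeout=300):
    # Run one build step; return (ok, truncated_error) like the helper above.
    try:
        res = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        if res.returncode != 0:
            return False, res.stderr[:200]
        return True, ""
    except subprocess.TimeoutExpired:
        return False, "Timeout"

ok, err = run_step([sys.executable, "-c", "pass"])
print(ok)   # True
ok, err = run_step([sys.executable, "-c", "import sys; sys.exit(3)"])
print(ok)   # False
```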
def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], str]:
"""Module-level function: generate ONE kernel .hpp via --config JSON file.
Used by setup_multiple_gemm_dispatchers for per-config parallel codegen.
Returns (success, header_path_or_None, error_msg).
"""
import subprocess
import json
import tempfile
import os
from pathlib import Path
try:
out_dir = Path(args["output_dir"])
out_dir.mkdir(parents=True, exist_ok=True)
# Write the single-config JSON to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(args["tile_config_json"], f)
config_file = f.name
cmd = [
args["python"],
str(args["codegen_script"]),
"--output-dir",
str(out_dir),
"--datatype",
args["dtype"],
"--layout",
args["layout"],
"--gpu-target",
args["gpu_target"],
"--config",
config_file,
"--variants",
"standard",
]
res = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
os.unlink(config_file)
if res.returncode != 0:
return False, None, f"Codegen failed: {res.stderr[:200]}"
# Find the generated .hpp using the expected name pattern
pattern = args["hpp_glob_pattern"]
matches = sorted(out_dir.glob(pattern))
if matches:
return True, str(matches[0]), ""
else:
return False, None, f"No .hpp matching {pattern} after codegen"
except Exception as e:
return False, None, str(e)
def _parse_triplet(text: str) -> Optional[Tuple[int, int, int]]:
"""Parse an 'MxNxK' string into a 3-int tuple, or None if malformed."""
parts = text.split("x")
if len(parts) != 3:
return None
try:
return (int(parts[0]), int(parts[1]), int(parts[2]))
except ValueError:
return None
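The triplet parser accepts exactly three `x`-separated integers and returns `None` for anything else. A self-contained copy with sample inputs, for illustration:

```python
from typing import Optional, Tuple

def parse_triplet(text: str) -> Optional[Tuple[int, int, int]]:
    # "256x256x64" -> (256, 256, 64); None for wrong arity or non-integer fields.
    parts = text.split("x")
    if len(parts) != 3:
        return None
    try:
        return (int(parts[0]), int(parts[1]), int(parts[2]))
    except ValueError:
        return None

print(parse_triplet("256x256x64"))  # (256, 256, 64)
print(parse_triplet("256x64"))      # None
print(parse_triplet("ax2x3"))       # None
```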
def _parse_gemm_header_metadata(header: Path) -> Optional[Dict[str, Any]]:
"""
Parse GEMM header name into configuration metadata.
Expected stem format:
gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}
_{pad_m}_{pad_n}_{pad_k}_{persistent}
_{tile_m}x{tile_n}x{tile_k}_{wave_m}x{wave_n}x{wave_k}_{warp_m}x{warp_n}x{warp_k}
"""
parts = header.stem.split("_")
if len(parts) < 13 or parts[0] != "gemm":
return None
tile = _parse_triplet(parts[10])
wave = _parse_triplet(parts[11])
warp = _parse_triplet(parts[12])
if tile is None or wave is None or warp is None:
return None
def _as_bool(v: str) -> bool:
return v.lower() == "true"
return {
"dtype": parts[1],
"layout": parts[2],
"pipeline": parts[3],
"epilogue": parts[4],
"scheduler": parts[5],
"pad_m": _as_bool(parts[6]),
"pad_n": _as_bool(parts[7]),
"pad_k": _as_bool(parts[8]),
"persistent": _as_bool(parts[9]),
"tile": tile,
"wave": wave,
"warp": warp,
}
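To make the stem layout documented in the docstring concrete, here is a condensed, standalone copy of the parser applied to one header name that follows the 13-field pattern (the filename itself is a made-up example):

```python
from pathlib import Path

def parse_triplet(text):
    parts = text.split("x")
    if len(parts) != 3:
        return None
    try:
        return tuple(int(p) for p in parts)
    except ValueError:
        return None

def parse_stem(stem):
    # Mirrors the expected 13-field stem layout documented above.
    parts = stem.split("_")
    if len(parts) < 13 or parts[0] != "gemm":
        return None
    tile, wave, warp = (parse_triplet(parts[i]) for i in (10, 11, 12))
    if None in (tile, wave, warp):
        return None
    return {"dtype": parts[1], "layout": parts[2], "pipeline": parts[3],
            "tile": tile, "wave": wave, "warp": warp}

h = Path("gemm_fp16_rcr_compv3_cshuffle_intrawave_false_false_false_false_"
         "256x256x64_2x2x1_32x32x16.hpp")
meta = parse_stem(h.stem)
print(meta["dtype"], meta["tile"])  # fp16 (256, 256, 64)
```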
def _generate_arch_valid_gemm_headers(
python_exe: str,
codegen_script: Path,
output_dir: Path,
dtype: str,
layout: str,
gpu_target: str,
variant: str = "standard",
) -> Tuple[bool, List[Path], str]:
"""Generate (or reuse) an arch-filtered kernel catalog for fallback selection."""
output_dir.mkdir(parents=True, exist_ok=True)
pattern = f"gemm_{dtype}_{layout}_*.hpp"
existing = sorted(output_dir.glob(pattern))
if existing:
return True, existing, ""
cmd = [
python_exe,
str(codegen_script),
"--output-dir",
str(output_dir),
"--datatype",
dtype,
"--layout",
layout,
"--gpu-target",
gpu_target,
"--variants",
variant,
]
res = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
if res.returncode != 0:
err = (res.stderr or res.stdout or "").strip()[:500]
return False, [], f"Catalog codegen failed: {err}"
generated = sorted(output_dir.glob(pattern))
if not generated:
return False, [], "Catalog codegen produced no GEMM headers"
return True, generated, ""
def _select_best_arch_valid_gemm_header(
config: "KernelConfig",
headers: List[Path],
) -> Tuple[Optional[Path], Optional[Dict[str, Any]]]:
"""Choose nearest arch-valid header for a requested GEMM config."""
best: Optional[Path] = None
best_meta: Optional[Dict[str, Any]] = None
best_score: Optional[Tuple[int, int, int, int, int, int]] = None
for h in headers:
meta = _parse_gemm_header_metadata(h)
if meta is None:
continue
if meta["dtype"] != config.dtype_a or meta["layout"] != config.layout:
continue
tile = meta["tile"]
wave = meta["wave"]
warp = meta["warp"]
tile_delta = (
abs(tile[0] - config.tile_m)
+ abs(tile[1] - config.tile_n)
+ abs(tile[2] - config.tile_k)
)
wave_delta = (
abs(wave[0] - config.wave_m)
+ abs(wave[1] - config.wave_n)
+ abs(wave[2] - config.wave_k)
)
warp_delta = (
abs(warp[0] - config.warp_m)
+ abs(warp[1] - config.warp_n)
+ abs(warp[2] - config.warp_k)
)
score = (
0 if meta["pipeline"] == config.pipeline else 1,
0 if meta["scheduler"] == config.scheduler else 1,
0 if meta["epilogue"] == config.epilogue else 1,
tile_delta,
wave_delta,
warp_delta,
)
if best_score is None or score < best_score:
best_score = score
best = h
best_meta = meta
return best, best_meta
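The selection above ranks candidates by a lexicographic score tuple, so a trait (pipeline/scheduler/epilogue) match always outweighs a smaller tile distance. A tiny sketch with made-up scores showing why tuple comparison gives that priority ordering:

```python
# Score tuples: (pipeline_mismatch, scheduler_mismatch, epilogue_mismatch,
#                tile_delta, wave_delta, warp_delta).  Python compares tuples
# left to right, so a trait match always beats a closer tile size.
candidates = {
    "exact_traits_far_tile":    (0, 0, 0, 384, 0, 0),
    "wrong_pipeline_same_tile": (1, 0, 0, 0, 0, 0),
    "exact_traits_near_tile":   (0, 0, 0, 64, 0, 0),
}
best = min(candidates, key=candidates.get)
print(best)  # exact_traits_near_tile
```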
# =============================================================================
# Preshuffle Utilities
# =============================================================================
@@ -1319,7 +1576,7 @@ class CodegenRunner:
result = future.result()
results.append(result)
if verbose:
-status = "" if result.success else ""
+status = "OK" if result.success else "FAIL"
print(
f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s"
)
@@ -1337,7 +1594,7 @@ class CodegenRunner:
)
)
if verbose:
-print(f" {variant}: FAILED - {e}")
+print(f" FAIL {variant}: FAILED - {e}")
total_time = time.time() - start_total
if verbose:
@@ -1399,7 +1656,7 @@ class CodegenRunner:
result = future.result()
results.append(result)
if verbose:
-status = "" if result.success else ""
+status = "OK" if result.success else "FAIL"
print(
f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s"
)
@@ -1417,7 +1674,7 @@ class CodegenRunner:
)
)
if verbose:
-print(f" {tile_str}: FAILED - {e}")
+print(f" FAIL {tile_str}: FAILED - {e}")
total_time = time.time() - start_total
if verbose:
@@ -1481,7 +1738,7 @@ class CodegenRunner:
result = future.result()
results.append(result)
if verbose:
-status = "" if result.success else ""
+status = "OK" if result.success else "FAIL"
print(
f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s"
)
@@ -1499,7 +1756,7 @@ class CodegenRunner:
)
)
if verbose:
-print(f" {variant}: FAILED - {e}")
+print(f" FAIL {variant}: FAILED - {e}")
total_time = time.time() - start_total
if verbose:
@@ -1767,7 +2024,7 @@ class CodegenRunner:
link_cmd, capture_output=True, text=True, timeout=300
)
if result.returncode == 0:
-print(f" Library rebuilt: {lib_path.name}")
+print(f" OK Library rebuilt: {lib_path.name}")
# Clean up object file
obj_file.unlink(missing_ok=True)
return lib_path
@@ -1781,6 +2038,105 @@ class CodegenRunner:
print(f" Build error: {e}")
return None
def build_libraries_parallel(
self, configs_and_headers: List[Tuple[KernelConfig, Path]], verbose: bool = True
) -> List[Optional[Path]]:
"""
Build multiple libraries in parallel using ProcessPoolExecutor.
Returns a list of library paths (or None if a build failed) in the same order.
"""
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
start_time = time.time()
build_dir = get_build_dir()
root = get_dispatcher_root()
ck_root = root.parent
ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp"
static_lib = build_dir / "libck_tile_dispatcher.a"
if not ctypes_source.exists() or not static_lib.exists():
if verbose:
print(" Required source or static library missing for parallel build.")
return [None] * len(configs_and_headers)
args_list = []
for config, kernel_header in configs_and_headers:
lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_{config.tile_str}_{config.pipeline}.so"
lib_path = build_dir / "examples" / lib_name
obj_file = lib_path.with_suffix(".o")
compile_cmd = [
"/opt/rocm/bin/hipcc",
"-c",
"-fPIC",
"-O3",
f"-I{root / 'include'}",
f"-I{ck_root / 'include'}",
f"-I{ck_root}",
f"-I{root / 'build/generated_kernels'}",
"-DCK_TILE_SINGLE_KERNEL_INCLUDE",
f"-include{kernel_header}",
"-D__HIP_PLATFORM_AMD__",
f"--offload-arch={config.gfx_arch}",
f'-DGFX_ARCH="{config.gfx_arch}"',
"-mllvm",
"-enable-noalias-to-md-conversion=0",
"-Wno-undefined-func-template",
"-Wno-float-equal",
str(ctypes_source),
"-o",
str(obj_file),
]
link_cmd = [
"/opt/rocm/bin/hipcc",
"-shared",
"-fPIC",
f"--offload-arch={config.gfx_arch}",
"--hip-link",
str(obj_file),
str(static_lib),
"-o",
str(lib_path),
]
args_list.append(
{
"compile_cmd": compile_cmd,
"link_cmd": link_cmd,
"lib_path": str(lib_path),
"config_name": f"{config.dtype_a}_{config.layout}_{config.tile_str}",
}
)
if verbose:
print(
f"Building {len(args_list)} libraries in parallel (workers={self.max_workers})..."
)
results_map = {}
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(_run_hipcc_subprocess, args): i
for i, args in enumerate(args_list)
}
for future in as_completed(futures):
idx = futures[future]
success, lib_path, err = future.result()
results_map[idx] = Path(lib_path) if success else None
if verbose:
status = "OK" if success else f"FAIL ({err})"
print(
f" {status} {Path(lib_path).name if success else args_list[idx]['config_name']}"
)
if verbose:
elapsed = time.time() - start_time
print(f"Parallel build finished in {elapsed:.2f}s")
return [results_map[i] for i in range(len(configs_and_headers))]
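`build_libraries_parallel` uses a common order-preserving pattern: key each future by its input index, collect results as they complete, then re-emit in submission order. A minimal sketch of that pattern, using threads for portability (the dispatcher itself uses `ProcessPoolExecutor`; `build` is a stand-in job):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def build(x):
    return x * x  # stand-in for one compile+link job

inputs = [3, 1, 4, 1, 5]
results_map = {}
with ThreadPoolExecutor(max_workers=4) as ex:
    futures = {ex.submit(build, v): i for i, v in enumerate(inputs)}
    for fut in as_completed(futures):        # completion order is arbitrary
        results_map[futures[fut]] = fut.result()
ordered = [results_map[i] for i in range(len(inputs))]  # output order is not
print(ordered)  # [9, 1, 16, 1, 25]
```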
def generate_preselected(
self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None
) -> CodegenResult:
@@ -1933,6 +2289,28 @@ class Registry:
"""Bind to a loaded dispatcher library."""
self._lib = lib
def build(
self,
verbose: bool = False,
max_workers: Optional[int] = None,
) -> List["GemmSetupResult"]:
"""JIT-compile all kernels in this registry in parallel.
Args:
verbose: Print progress during build.
max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8).
Returns a GemmSetupResult per registered kernel (same order as get_kernels()).
"""
if not self._kernels:
return []
return setup_multiple_gemm_dispatchers(
self._kernels,
registry_name=self._name,
verbose=verbose,
max_workers=max_workers,
)
def __repr__(self) -> str:
return f"Registry(name='{self._name}', kernels={self.kernel_count})"
@@ -2109,7 +2487,7 @@ def setup_gemm_dispatcher(
log(" Validating config...")
validation = validate_kernel_config(config)
if not validation.is_valid:
-log(" Auto-correcting configuration...")
+log(" WARNING Auto-correcting configuration...")
config, was_modified, corrections = auto_correct_kernel_config(
config, verbose=verbose
)
@@ -2128,13 +2506,13 @@ def setup_gemm_dispatcher(
codegen_result = codegen.generate_from_config(config)
if not codegen_result.success:
-log(" Kernel generation: using existing")
+log(" WARNING Kernel generation: using existing")
# Step 3: Find matching kernel header
kernel_header = find_matching_kernel_header(config)
result.kernel_header = kernel_header
if not kernel_header:
-log(" No matching kernel header found")
+log(" WARNING No matching kernel header found")
# Step 4: Load library
log(" Loading library...")
@@ -2188,11 +2566,11 @@ def setup_gemm_dispatcher(
result.error = "Failed to load rebuilt library"
return result
result.lib = lib
-log(f" Rebuilt library: {lib.get_kernel_name()}")
+log(f" OK Rebuilt library: {lib.get_kernel_name()}")
else:
-log(" Rebuild failed, using existing library")
+log(" WARNING Rebuild failed, using existing library")
else:
-log(" No kernel header found for config, using existing library")
+log(" WARNING No kernel header found for config, using existing library")
# Step 5: Create registry and dispatcher
log(" Creating registry and dispatcher...")
@@ -2203,12 +2581,305 @@ def setup_gemm_dispatcher(
dispatcher = Dispatcher(registry=registry, lib=lib)
result.dispatcher = dispatcher
-log(f" Ready: {lib.get_kernel_name()}")
+log(f" OK Ready: {lib.get_kernel_name()}")
result.success = True
return result
def setup_multiple_gemm_dispatchers(
configs: List[KernelConfig],
registry_name: str = "gemm_registry",
verbose: bool = True,
max_workers: Optional[int] = None,
) -> List[GemmSetupResult]:
"""
Setup multiple GEMM dispatchers in parallel.
Pipeline:
1. Validate + auto-correct each config
2. Parallel codegen: generate .hpp for each config via --config JSON
3. Parallel hipcc: compile each .hpp -> .so
4. Load + wire up each .so into a GemmSetupResult
Each config gets its own .so, so different tile sizes can coexist.
Args:
max_workers: Max parallel processes for codegen/compile (default: cpu_count capped at 8).
"""
import sys
results = [GemmSetupResult(success=False, config=c) for c in configs]
max_workers = max_workers or min(multiprocessing.cpu_count(), 8)
# -- Step 1: Validate & correct ---------------------------------------
valid_configs = []
for i, c in enumerate(configs):
val = validate_kernel_config(c)
if not val.is_valid:
c, modified, corrections = auto_correct_kernel_config(c, verbose=False)
results[i].config = c
results[i].corrections = corrections
valid_configs.append(c)
# -- Step 2: Parallel codegen (one --config JSON per config) ----------
codegen_script = get_codegen_path()
output_dir = get_generated_kernels_dir()
codegen_args = []
for c in valid_configs:
tile_str = c.tile_str
wave_str = f"{c.wave_m}x{c.wave_n}x{c.wave_k}"
warp_str = f"{c.warp_m}x{c.warp_n}x{c.warp_k}"
tile_config_json = {
"tile_config": {
"tile_m": [c.tile_m],
"tile_n": [c.tile_n],
"tile_k": [c.tile_k],
"warp_m": [c.wave_m],
"warp_n": [c.wave_n],
"warp_k": [c.wave_k],
"warp_tile_m": [c.warp_m],
"warp_tile_n": [c.warp_n],
"warp_tile_k": [c.warp_k],
},
"trait_config": {
"pipeline": [c.pipeline],
"epilogue": [c.epilogue],
"scheduler": [c.scheduler],
"pad_m": [c.pad_m],
"pad_n": [c.pad_n],
"pad_k": [c.pad_k],
"persistent": [False],
},
}
hpp_pattern = (
f"gemm_{c.dtype_a}_{c.layout}_{c.pipeline}_{c.epilogue}_{c.scheduler}"
f"_*_{tile_str}_{wave_str}_{warp_str}.hpp"
)
codegen_args.append(
{
"python": sys.executable,
"codegen_script": str(codegen_script),
"output_dir": str(output_dir),
"dtype": c.dtype_a,
"layout": c.layout,
"gpu_target": c.gfx_arch,
"tile_config_json": tile_config_json,
"hpp_glob_pattern": hpp_pattern,
}
)
if verbose:
print(
f"Generating {len(codegen_args)} kernel headers in parallel (workers={max_workers})..."
)
headers: List[Optional[Path]] = [None] * len(valid_configs)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_generate_single_kernel_subprocess, a): i
for i, a in enumerate(codegen_args)
}
for future in as_completed(futures):
idx = futures[future]
ok, hdr_str, err = future.result()
if ok and hdr_str:
headers[idx] = Path(hdr_str)
results[idx].kernel_header = Path(hdr_str)
if verbose:
print(
f" OK [{idx}] {valid_configs[idx].tile_str}: {Path(hdr_str).name}"
)
else:
results[idx].error = f"Codegen: {err}"
if verbose:
print(f" FAIL [{idx}] {valid_configs[idx].tile_str}: {err}")
# For configs rejected by arch filter, map to nearest arch-valid header.
fallback_needed = [i for i, h in enumerate(headers) if h is None]
if fallback_needed:
if verbose:
print(
f"Resolving {len(fallback_needed)} configs via arch-valid GEMM catalog..."
)
catalog_cache: Dict[Tuple[str, str, str, str], List[Path]] = {}
for i in fallback_needed:
c = valid_configs[i]
key = (c.gfx_arch, c.dtype_a, c.layout, c.variant)
if key not in catalog_cache:
catalog_dir = (
output_dir
/ "_arch_valid_catalog"
/ (f"{c.gfx_arch}_{c.dtype_a}_{c.layout}_{c.variant}")
)
ok, catalog_headers, err = _generate_arch_valid_gemm_headers(
python_exe=sys.executable,
codegen_script=codegen_script,
output_dir=catalog_dir,
dtype=c.dtype_a,
layout=c.layout,
gpu_target=c.gfx_arch,
variant=c.variant,
)
if not ok:
catalog_headers = []
if verbose:
print(f" FAIL [{i}] catalog generation: {err}")
catalog_cache[key] = catalog_headers
chosen, meta = _select_best_arch_valid_gemm_header(c, catalog_cache[key])
if chosen is None or meta is None:
continue
headers[i] = chosen
results[i].kernel_header = chosen
results[i].error = ""
# Keep Python-side config aligned with the selected kernel header.
valid_configs[i].pipeline = str(meta["pipeline"])
valid_configs[i].epilogue = str(meta["epilogue"])
valid_configs[i].scheduler = str(meta["scheduler"])
valid_configs[i].pad_m = bool(meta["pad_m"])
valid_configs[i].pad_n = bool(meta["pad_n"])
valid_configs[i].pad_k = bool(meta["pad_k"])
valid_configs[i].tile_m = int(meta["tile"][0])
valid_configs[i].tile_n = int(meta["tile"][1])
valid_configs[i].tile_k = int(meta["tile"][2])
valid_configs[i].wave_m = int(meta["wave"][0])
valid_configs[i].wave_n = int(meta["wave"][1])
valid_configs[i].wave_k = int(meta["wave"][2])
valid_configs[i].warp_m = int(meta["warp"][0])
valid_configs[i].warp_n = int(meta["warp"][1])
valid_configs[i].warp_k = int(meta["warp"][2])
results[i].config = valid_configs[i]
if verbose:
print(f" INFO [{i}] mapped to arch-valid header: {chosen.name}")
# -- Step 3: Parallel hipcc compilation -------------------------------
root = get_dispatcher_root()
ck_root = root.parent
build_dir = get_build_dir()
ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp"
static_lib = build_dir / "libck_tile_dispatcher.a"
if not ctypes_source.exists() or not static_lib.exists():
for i in range(len(valid_configs)):
if results[i].error == "":
results[i].error = "Missing ctypes source or static library for compilation"
return results
compile_jobs = []
compile_index_map = {}
for i, c in enumerate(valid_configs):
hdr = headers[i]
if hdr is None:
continue
lib_name = (
f"libdispatcher_gemm_{c.dtype_a}_{c.layout}_{c.tile_str}_{c.pipeline}.so"
)
lib_path = build_dir / "examples" / lib_name
obj_file = lib_path.with_suffix(".o")
compile_cmd = [
"/opt/rocm/bin/hipcc",
"-c",
"-fPIC",
"-O3",
f"-I{root / 'include'}",
f"-I{ck_root / 'include'}",
f"-I{ck_root}",
f"-I{str(output_dir)}",
"-DCK_TILE_SINGLE_KERNEL_INCLUDE",
f"-include{hdr}",
"-D__HIP_PLATFORM_AMD__",
f"--offload-arch={c.gfx_arch}",
f'-DGFX_ARCH="{c.gfx_arch}"',
"-mllvm",
"-enable-noalias-to-md-conversion=0",
"-Wno-undefined-func-template",
"-Wno-float-equal",
str(ctypes_source),
"-o",
str(obj_file),
]
link_cmd = [
"/opt/rocm/bin/hipcc",
"-shared",
"-fPIC",
f"--offload-arch={c.gfx_arch}",
"--hip-link",
str(obj_file),
str(static_lib),
"-o",
str(lib_path),
]
compile_index_map[len(compile_jobs)] = i
compile_jobs.append(
{
"compile_cmd": compile_cmd,
"link_cmd": link_cmd,
"lib_path": str(lib_path),
}
)
if verbose and compile_jobs:
print(
f"Compiling {len(compile_jobs)} libraries in parallel (workers={max_workers})..."
)
lib_paths: Dict[int, Optional[Path]] = {}
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_run_hipcc_subprocess, job): j
for j, job in enumerate(compile_jobs)
}
for future in as_completed(futures):
j = futures[future]
i = compile_index_map[j]
ok, lp, err = future.result()
if ok and lp:
lib_paths[i] = Path(lp)
if verbose:
print(f" OK [{i}] {valid_configs[i].tile_str}: {Path(lp).name}")
else:
results[i].error = f"Compile: {err}"
if verbose:
print(f" FAIL [{i}] {valid_configs[i].tile_str}: {err}")
# -- Step 4: Load libraries and create dispatchers --------------------
for i, c in enumerate(valid_configs):
lp = lib_paths.get(i)
if lp is None:
continue
lib = DispatcherLib.load(lp)
if lib is not None and lib.initialize():
results[i].lib = lib
reg = Registry(name=f"{registry_name}_{i}", lib=lib)
reg.register_kernel(c)
results[i].registry = reg
results[i].dispatcher = Dispatcher(registry=reg, lib=lib)
results[i].success = True
else:
results[i].error = "Failed to load compiled library"
if verbose:
ok_count = sum(1 for r in results if r.success)
print(f"Setup complete: {ok_count}/{len(results)} dispatchers ready")
return results
def cleanup_gemm():
"""
Cleanup function to call after running GEMM examples.


@@ -0,0 +1,372 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Shared Python dispatcher utilities for GEMM and grouped convolution.
Extracted from ctypes_utils.py (GEMM) and compile_grouped_conv_examples.py (grouped conv).
Both ctypes_utils.py and grouped_conv_utils.py import from here to
eliminate duplication.
Best-of-both:
- Validation and auto-correction return typed objects (GEMM pattern)
- Colors class with cross-platform ANSI handling (conv pattern)
- Phased output helpers (conv pattern)
- logging module instead of bare print() (shared improvement)
"""
import logging
import shutil
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
log = logging.getLogger(__name__)
# ============================================================================
# Path Configuration
# ============================================================================
def get_dispatcher_root() -> Path:
"""Get the dispatcher root directory (parent of python/)."""
return Path(__file__).parent.parent
def get_ck_root() -> Path:
"""Get the CK root directory (parent of dispatcher/)."""
return get_dispatcher_root().parent
def get_build_dir() -> Path:
"""Get the build directory."""
return get_dispatcher_root() / "build"
def get_generated_kernels_dir() -> Path:
"""Get the generated kernels directory."""
return get_build_dir() / "generated_kernels"
def get_codegen_dir() -> Path:
"""Get the codegen scripts directory."""
return get_dispatcher_root() / "codegen"
# ============================================================================
# Architecture Filter Data
# ============================================================================
_arch_data_cache: Optional[Dict[str, Any]] = None
def detect_gpu_arch(fallback: str = "gfx942") -> str:
"""Detect the GPU architecture from rocminfo. Falls back to the given default."""
import subprocess
try:
out = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.DEVNULL
)
for line in out.splitlines():
if "Name:" in line and "gfx" in line:
return line.split()[-1].strip()
except Exception:
pass
return fallback
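The `rocminfo` scan above can be exercised offline by factoring out the line parse. A standalone sketch fed with a fabricated `rocminfo`-style snippet (the sample text is illustrative, not captured output):

```python
def parse_arch(rocminfo_output: str, fallback: str = "gfx942") -> str:
    # First "Name:" line containing "gfx" wins; the last token is the arch.
    for line in rocminfo_output.splitlines():
        if "Name:" in line and "gfx" in line:
            return line.split()[-1].strip()
    return fallback

sample = """Agent 2
  Name:                    gfx942
  Uuid:                    GPU-0000"""
print(parse_arch(sample))         # gfx942
print(parse_arch("no gpu here"))  # gfx942 (fallback)
```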
def get_arch_filter_data() -> Dict[str, Any]:
"""Load arch filter data from arch_specs_generated if available.
Returns dict with keys: trait_unsupported, warp_combos,
warp_tile_combos, supported_archs.
"""
global _arch_data_cache
if _arch_data_cache is not None:
return _arch_data_cache
codegen_dir = get_dispatcher_root() / "codegen"
sys.path.insert(0, str(codegen_dir))
try:
from arch_specs_generated import (
TRAIT_UNSUPPORTED_COMBINATIONS,
WARP_SUPPORTED_COMBINATIONS,
WARP_TILE_SUPPORTED_COMBINATIONS,
get_supported_archs,
)
_arch_data_cache = {
"trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS,
"warp_combos": WARP_SUPPORTED_COMBINATIONS,
"warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS,
"supported_archs": get_supported_archs(),
}
except ImportError:
_arch_data_cache = {
"trait_unsupported": {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
},
"warp_combos": {
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
},
"warp_tile_combos": {
"gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
"gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
},
"supported_archs": ["gfx90a", "gfx942", "gfx950"],
}
return _arch_data_cache
# ============================================================================
# Validation Result
# ============================================================================
@dataclass
class ValidationResultBase:
"""Result of kernel config validation (shared base for GEMM and conv)."""
is_valid: bool
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
suggested_fixes: Dict[str, Any] = field(default_factory=dict)
def print_result(self, indent: str = " "):
if self.is_valid:
print(f"{indent}OK Configuration valid")
else:
print(f"{indent}WARNING Configuration has issues:")
for err in self.errors:
print(f"{indent} - {err}")
if self.warnings:
for warn in self.warnings:
print(f"{indent} Warning: {warn}")
if self.suggested_fixes:
print(f"{indent} Suggested fixes:")
for key, val in self.suggested_fixes.items():
print(f"{indent} {key}: {val}")
# ============================================================================
# Validation Helpers
# ============================================================================
def validate_wave_config(wave_cfg: List[int], arch: str) -> Tuple[bool, str]:
"""Validate a [wave_m, wave_n, wave_k] config for *arch*.
Returns (is_valid, error_message). Empty string on success.
"""
data = get_arch_filter_data()
valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]])
if wave_cfg in valid_waves:
return True, ""
valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_waves)
return (
False,
f"Unsupported wave configuration {wave_cfg} for {arch}. "
f"Valid wave configs: {valid_str}",
)
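The validator's membership check and error message can be shown in isolation. A condensed copy that takes the arch table as a parameter instead of calling `get_arch_filter_data()`, using the gfx942 fallback table from above:

```python
from typing import List, Tuple

def validate_wave(wave_cfg: List[int], valid_waves: List[List[int]]) -> Tuple[bool, str]:
    # Same shape as validate_wave_config, with the arch table passed in directly.
    if wave_cfg in valid_waves:
        return True, ""
    valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_waves)
    return False, f"Unsupported wave configuration {wave_cfg}. Valid: {valid_str}"

gfx942_waves = [[1, 4, 1], [2, 2, 1], [4, 1, 1]]  # fallback table from above
print(validate_wave([2, 2, 1], gfx942_waves))  # (True, '')
ok, msg = validate_wave([8, 1, 1], gfx942_waves)
print(ok)  # False
```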
def validate_warp_tile_config(
warp_cfg: List[int], arch: str, dtype: str
) -> Tuple[bool, str]:
"""Validate a [warp_m, warp_n, warp_k] config for *arch*/*dtype*.
Returns (is_valid, error_message). Empty string on success.
"""
data = get_arch_filter_data()
acc = "int32" if dtype == "int8" else "fp32"
dtype_key = f"{dtype}_{dtype}_{acc}"
valid_tiles = (
data["warp_tile_combos"]
.get(arch, {})
.get(dtype_key, [[32, 32, 16], [16, 16, 16]])
)
if warp_cfg in valid_tiles:
return True, ""
valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_tiles[:5])
return (
False,
f"Unsupported warp tile {warp_cfg} for {arch}/{dtype}. "
f"Valid warp tiles: {valid_str}",
)
def validate_trait_combo(
pipeline: str, epilogue: str, scheduler: str
) -> Tuple[bool, str]:
"""Validate a (pipeline, epilogue, scheduler) combination.
Returns (is_valid, error_message). Empty string on success.
"""
data = get_arch_filter_data()
combo = (pipeline, epilogue, scheduler)
if combo in data["trait_unsupported"]:
return (
False,
f"Unsupported trait combination: pipeline={pipeline}, "
f"epilogue={epilogue}, scheduler={scheduler}",
)
return True, ""
# ============================================================================
# Auto-Correction Helpers
# ============================================================================
def auto_correct_wave(wave_cfg: List[int], arch: str) -> List[int]:
"""Return the first valid wave config for *arch*.
If *wave_cfg* is already valid, returns it unchanged.
"""
data = get_arch_filter_data()
valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]])
if wave_cfg in valid_waves:
return wave_cfg
return valid_waves[0] if valid_waves else [2, 2, 1]
def auto_correct_trait(pipeline: str, scheduler: str) -> Tuple[str, str]:
"""Return a corrected (pipeline, scheduler) pair.
If the compute pipeline doesn't support interwave, switch to intrawave.
"""
data = get_arch_filter_data()
for epilogue in ("cshuffle", "default"):
if (pipeline, epilogue, scheduler) in data["trait_unsupported"]:
return pipeline, "intrawave"
return pipeline, scheduler
# ============================================================================
# Colors (adopted from compile_grouped_conv_examples.py -- cross-platform)
# ============================================================================
class Colors:
"""Cross-platform ANSI color support.
Respects sys.platform (no ANSI on Windows) and an isatty() check so
piped/redirected output stays clean.
"""
_GREEN = "\033[0;32m"
_YELLOW = "\033[1;33m"
_RED = "\033[0;31m"
_CYAN = "\033[0;36m"
_BOLD = "\033[1m"
_NC = "\033[0m"
@classmethod
def _use_color(cls) -> bool:
return (
sys.platform != "win32"
and hasattr(sys.stdout, "isatty")
and sys.stdout.isatty()
)
@classmethod
def green(cls, text: str) -> str:
if cls._use_color():
return f"{cls._GREEN}{text}{cls._NC}"
return text
@classmethod
def red(cls, text: str) -> str:
if cls._use_color():
return f"{cls._RED}{text}{cls._NC}"
return text
@classmethod
def yellow(cls, text: str) -> str:
if cls._use_color():
return f"{cls._YELLOW}{text}{cls._NC}"
return text
@classmethod
def cyan(cls, text: str) -> str:
if cls._use_color():
return f"{cls._CYAN}{text}{cls._NC}"
return text
@classmethod
def bold(cls, text: str) -> str:
if cls._use_color():
return f"{cls._BOLD}{text}{cls._NC}"
return text
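The guard the `Colors` class applies before every escape sequence can be condensed into one function; a sketch showing that a non-TTY stream (here a `StringIO`) gets plain text back:

```python
import io
import sys

def colorize(text: str, code: str, stream=sys.stdout) -> str:
    # Emit ANSI only for an interactive, non-Windows TTY -- same guard as Colors.
    use = (sys.platform != "win32"
           and hasattr(stream, "isatty") and stream.isatty())
    return f"\033[{code}m{text}\033[0m" if use else text

buf = io.StringIO()                   # isatty() is False, so no escape codes
print(colorize("done", "0;32", buf))  # done
```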
# ============================================================================
# Phased Output Helpers
# ============================================================================
def print_phase(number: int, description: str) -> None:
"""Print a phase header (e.g. 'Phase 1: Codegen')."""
print(f"\n{'=' * 60}")
print(f" Phase {number}: {description}")
print(f"{'=' * 60}")
def print_success(message: str) -> None:
"""Print a success message."""
print(f" OK {Colors.green(message)}")
def print_error(message: str) -> None:
"""Print an error message."""
print(f" FAIL {Colors.red(message)}")
def print_info(message: str) -> None:
"""Print an info message."""
print(f" {Colors.cyan(message)}")
# ============================================================================
# Cleanup Helpers
# ============================================================================
def cleanup_generated_kernels(gen_dir: Optional[Path] = None) -> None:
"""Remove generated kernel directory if it exists."""
if gen_dir is None:
gen_dir = get_generated_kernels_dir()
if gen_dir.exists():
shutil.rmtree(gen_dir, ignore_errors=True)
log.info("Cleaned up generated kernels at %s", gen_dir)
# ============================================================================
# Tool Helpers
# ============================================================================
def find_hipcc() -> Optional[str]:
"""Find the hipcc compiler."""
import os
candidates = [
os.environ.get("HIPCC"),
"/opt/rocm/bin/hipcc",
shutil.which("hipcc"),
]
for path in candidates:
if path and os.path.isfile(path):
return path
return None
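`find_hipcc` implements a three-step lookup precedence; a generic sketch of the same precedence (`find_tool` and its arguments are illustrative names):

```python
import os
import shutil

def find_tool(name, env_var, fixed_path):
    # Precedence: env override, well-known install path, then PATH lookup.
    for cand in (os.environ.get(env_var), fixed_path, shutil.which(name)):
        if cand and os.path.isfile(cand):
            return cand
    return None

# With no env var set, no fixed path, and nothing on PATH, the result is None.
print(find_tool("no-such-tool-xyz", "NO_SUCH_ENV", "/no/such/path"))  # None
```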

(File diff suppressed because it is too large.)

@@ -94,17 +94,17 @@ def find_hipcc() -> str:
def extract_conv_kernel_declarations(source_file: Path) -> list:
-"""Extract CONVOLUTION kernel declarations from C++ source file.
+"""Extract GROUPED CONVOLUTION kernel declarations from C++ source file.
-Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern.
+Supports DECL_GROUPED_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern.
Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler.
"""
content = source_file.read_text()
declarations = []
seen = set()
-# Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...))
-set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)"
+# Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...))
+set_pattern = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)"
for match in re.finditer(set_pattern, content, re.DOTALL):
set_name = match.group(1)
@@ -396,24 +396,23 @@ def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -
def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int:
-"""Generate convolution kernels using unified_conv_codegen."""
+"""Generate grouped convolution kernels using unified_grouped_conv_codegen."""
kernel_dir = get_generated_kernels_dir()
kernel_dir.mkdir(parents=True, exist_ok=True)
# Import conv codegen
codegen_dir = get_dispatcher_root() / "codegen"
sys.path.insert(0, str(codegen_dir))
try:
-from unified_conv_codegen import (
-UnifiedConvCodegen,
-ConvKernelConfig,
-ConvVariant,
+from unified_grouped_conv_codegen import (
+UnifiedGroupedConvCodegen as UnifiedConvCodegen,
+GroupedConvKernelConfig as ConvKernelConfig,
+GroupedConvVariant as ConvVariant,
TileConfig,
-TraitConfig,
+GroupedConvTraitConfig as TraitConfig,
)
except ImportError as e:
-print_error(f" Failed to import conv codegen: {e}")
+print_error(f" Failed to import grouped conv codegen: {e}")
return 0
codegen = UnifiedConvCodegen(kernel_dir)
@@ -1564,9 +1563,9 @@ def build_exact_conv_kernel_filename(decl: dict) -> str:
if conv_type == "forward":
type_prefix = "fwd"
elif conv_type == "bwd_data":
-type_prefix = "bwdd"
+type_prefix = "bwd_data"
elif conv_type == "bwd_weight":
-type_prefix = "bwdw"
+type_prefix = "bwd_weight"
else:
type_prefix = conv_type
@@ -1601,9 +1600,9 @@ def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> boo
else:
variant = "forward"
-# Use unified_conv_codegen
+# Use unified_grouped_conv_codegen
codegen_dir = get_dispatcher_root() / "codegen"
-codegen_script = codegen_dir / "unified_conv_codegen.py"
+codegen_script = codegen_dir / "unified_grouped_conv_codegen.py"
output_dir = get_generated_kernels_dir()
cmd = [
@@ -1661,9 +1660,9 @@ def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path:
if conv_type == "forward":
type_prefix = "fwd"
elif conv_type == "bwd_data":
type_prefix = "bwdd"
type_prefix = "bwd_data"
elif conv_type == "bwd_weight":
type_prefix = "bwdw"
type_prefix = "bwd_weight"
else:
type_prefix = conv_type
@@ -1865,7 +1864,9 @@ In your C++ code, declare kernels like:
if not gemm_declarations and not conv_declarations:
print_error(" No kernel declarations found!")
print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv")
print(
" Add DECL_KERNEL_SET for GEMM or DECL_GROUPED_CONV_KERNEL_SET for Grouped Conv"
)
return 1
# Handle GEMM declarations
@@ -1913,7 +1914,7 @@ In your C++ code, declare kernels like:
is_valid, error_msg = validate_kernel_config(decl, arch)
if not is_valid:
print(f"\n Invalid configuration: {decl_name}")
print(f"\n WARNING Invalid configuration: {decl_name}")
# Parse the error and show specific auto-corrections
corrections = []
@@ -1926,7 +1927,7 @@ In your C++ code, declare kernels like:
decl["wave_m"] = -1
decl["wave_n"] = -1
corrections.append(
-f"wave: {original_values['wave']} [wildcard expansion]"
+f"wave: {original_values['wave']} -> [wildcard expansion]"
)
if "warp tile" in error_msg.lower():
@@ -1936,7 +1937,7 @@ In your C++ code, declare kernels like:
decl["warp_m"] = -1
decl["warp_n"] = -1
corrections.append(
-f"warp_tile: {original_values['warp']} [wildcard expansion]"
+f"warp_tile: {original_values['warp']} -> [wildcard expansion]"
)
if "trait combination" in error_msg.lower():
@@ -1945,16 +1946,16 @@ In your C++ code, declare kernels like:
decl["pipeline"] = "*"
decl["scheduler"] = "*"
corrections.append(
-f"pipeline: {original_values['pipeline']} [wildcard expansion]"
+f"pipeline: {original_values['pipeline']} -> [wildcard expansion]"
)
corrections.append(
-f"scheduler: {original_values['scheduler']} [wildcard expansion]"
+f"scheduler: {original_values['scheduler']} -> [wildcard expansion]"
)
# Print the auto-corrections
print(" AUTO-CORRECTION:")
for corr in corrections:
print(f" {corr}")
print(f" - {corr}")
auto_corrections.append((decl_name, corrections))
invalid_count += 1
@@ -1962,15 +1963,15 @@ In your C++ code, declare kernels like:
if invalid_count > 0:
print(
f"\n {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
)
if wildcard_count > 0:
print(
f" {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
f" OK {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
)
else:
print(f" All {len(gemm_declarations)} configurations valid")
print(f" OK All {len(gemm_declarations)} configurations valid")
# Expand GEMM declarations (for wildcards)
print("\n Expanding wildcards to valid configurations...")
@@ -1994,7 +1995,7 @@ In your C++ code, declare kernels like:
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(
f" wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
)
if len(expanded) > 3:
print(f" ... and {len(expanded) - 3} more")
@@ -2002,11 +2003,11 @@ In your C++ code, declare kernels like:
exp = expanded[0]
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(f" {decl_name}: wave={wave_str}, warp={warp_str}")
print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}")
if len(expanded_gemm) > len(gemm_declarations):
print(
f"\n Total: {len(gemm_declarations)} declarations {len(expanded_gemm)} configurations"
f"\n Total: {len(gemm_declarations)} declarations -> {len(expanded_gemm)} configurations"
)
gemm_declarations = expanded_gemm
@@ -2054,7 +2055,7 @@ In your C++ code, declare kernels like:
is_valid, error_msg = validate_conv_kernel_config(decl, arch)
if not is_valid:
print(f"\n Invalid conv configuration: {decl_name}")
print(f"\n WARNING Invalid conv configuration: {decl_name}")
# Parse the error and show specific auto-corrections
corrections = []
@@ -2067,7 +2068,7 @@ In your C++ code, declare kernels like:
decl["wave_m"] = -1
decl["wave_n"] = -1
corrections.append(
-f"wave: {original_values['wave']} [wildcard expansion]"
+f"wave: {original_values['wave']} -> [wildcard expansion]"
)
if "warp tile" in error_msg.lower():
@@ -2077,7 +2078,7 @@ In your C++ code, declare kernels like:
decl["warp_m"] = -1
decl["warp_n"] = -1
corrections.append(
-f"warp_tile: {original_values['warp']} [wildcard expansion]"
+f"warp_tile: {original_values['warp']} -> [wildcard expansion]"
)
if "trait combination" in error_msg.lower():
@@ -2086,16 +2087,16 @@ In your C++ code, declare kernels like:
decl["pipeline"] = "*"
decl["scheduler"] = "*"
corrections.append(
-f"pipeline: {original_values['pipeline']} [wildcard expansion]"
+f"pipeline: {original_values['pipeline']} -> [wildcard expansion]"
)
corrections.append(
-f"scheduler: {original_values['scheduler']} [wildcard expansion]"
+f"scheduler: {original_values['scheduler']} -> [wildcard expansion]"
)
# Print the auto-corrections
print(" AUTO-CORRECTION:")
for corr in corrections:
print(f" {corr}")
print(f" - {corr}")
auto_corrections.append((decl_name, corrections))
invalid_count += 1
@@ -2103,15 +2104,15 @@ In your C++ code, declare kernels like:
if invalid_count > 0:
print(
f"\n {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
)
if wildcard_count > 0:
print(
f" {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
f" OK {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
)
else:
print(f" All {len(conv_declarations)} configurations valid")
print(f" OK All {len(conv_declarations)} configurations valid")
# Expand Conv declarations (for wildcards)
print("\n Expanding wildcards to valid configurations...")
@@ -2134,7 +2135,7 @@ In your C++ code, declare kernels like:
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(
f" wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
)
if len(expanded) > 3:
print(f" ... and {len(expanded) - 3} more")
@@ -2142,11 +2143,11 @@ In your C++ code, declare kernels like:
exp = expanded[0]
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(f" {decl_name}: wave={wave_str}, warp={warp_str}")
print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}")
if len(expanded_conv) > len(conv_declarations):
print(
f"\n Total: {len(conv_declarations)} declarations {len(expanded_conv)} configurations"
f"\n Total: {len(conv_declarations)} declarations -> {len(expanded_conv)} configurations"
)
conv_declarations = expanded_conv
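The auto-correction messages in the hunks above come from the wildcard-expansion pass: an invalid field is reset to `-1` or `"*"`, then fanned out over the architecture's valid options while known-bad trait combinations are filtered before codegen. A minimal standalone sketch of that idea (the option tables and `expand` helper here are hypothetical placeholders, not the real arch filter data):

```python
from itertools import product

# Hypothetical stand-ins for the arch filter tables used by the real script.
VALID_WAVES = [[2, 2, 1], [1, 4, 1]]
VALID_PIPELINES = ["compv3", "compv4"]
TRAIT_UNSUPPORTED = {("compv3", "cshuffle", "interwave")}

def expand(decl):
    """Fan a wildcard declaration out into fully specified configs."""
    waves = VALID_WAVES if decl["wave"] == -1 else [decl["wave"]]
    pipes = VALID_PIPELINES if decl["pipeline"] == "*" else [decl["pipeline"]]
    out = []
    for wave, pipe in product(waves, pipes):
        # Skip combinations the arch filter marks unsupported.
        if (pipe, "cshuffle", decl["scheduler"]) in TRAIT_UNSUPPORTED:
            continue
        out.append({**decl, "wave": wave, "pipeline": pipe})
    return out

configs = expand({"wave": -1, "pipeline": "*", "scheduler": "intrawave"})
print(len(configs))  # 2 waves x 2 pipelines, none filtered -> 4
```

In the dispatcher scripts the same filtering also covers warp tiles and schedulers, and the expanded list replaces the original declarations before kernel generation.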


@@ -0,0 +1,882 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Self-contained build script for C++ grouped convolution examples.
Parses DECL_GROUPED_CONV_KERNEL_SET declarations from source files,
generates the needed kernels, and compiles the example.
Includes validation and auto-correction via wildcard expansion.
Usage:
python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/02_grouped_conv_forward.cpp
python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/03_grouped_conv_validation.cpp --no-compile
"""
import argparse
import os
import re
import subprocess
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
# Setup paths
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
CK_ROOT = DISPATCHER_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from dispatcher_common import ( # noqa: E402
print_phase,
print_success,
print_error,
print_info,
find_hipcc,
get_arch_filter_data,
get_build_dir,
get_ck_root,
get_dispatcher_root,
get_generated_kernels_dir,
)
def extract_grouped_conv_declarations(source_file: Path) -> list:
"""Extract DECL_GROUPED_CONV_KERNEL_SET declarations from C++ source."""
content = source_file.read_text()
declarations = []
# Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...))
# Find all DECL_GROUPED_CONV_KERNEL_SET blocks by matching parentheses
pattern_start = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,"
for match in re.finditer(pattern_start, content):
set_name = match.group(1)
start_pos = match.end()
# Find matching closing paren by counting parens
paren_count = 1 # We're already inside the first paren
end_pos = start_pos
for i, c in enumerate(content[start_pos:]):
if c == "(":
paren_count += 1
elif c == ")":
paren_count -= 1
if paren_count == 0:
end_pos = start_pos + i
break
set_body = content[start_pos:end_pos]
# Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c)
simple_add = (
r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)'
)
for add_match in re.finditer(simple_add, set_body):
conv_type = add_match.group(3)
default_pipeline = (
"compv3" if conv_type in ("bwd_data", "bwd_weight") else "compv4"
)
declarations.append(
{
"set": set_name,
"dtype": add_match.group(1),
"layout": add_match.group(2),
"conv_type": conv_type,
"tile_k": int(add_match.group(4)),
"tile_c": int(add_match.group(5)),
"num_dims": 2,
"pipeline": default_pipeline,
"scheduler": "intrawave",
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
"arch": "gfx942",
}
)
# Pattern 2: Full ConvSig()/ConvAlgo() specification
# Find all .add( positions that start with ConvSig()
full_add = r"\.add\s*\(\s*ConvSig\(\)"
add_positions = [m.start() for m in re.finditer(full_add, set_body)]
for pos in add_positions:
# Find matching closing paren by counting parens
paren_count = 0
in_add = False
end = pos
for i, c in enumerate(set_body[pos:]):
if c == "(":
paren_count += 1
in_add = True
elif c == ")":
paren_count -= 1
if in_add and paren_count == 0:
end = pos + i + 1
break
add_str = set_body[pos:end]
# Extract signature part (between ConvSig() and ConvAlgo())
sig_match = re.search(r"ConvSig\(\)(.*?)ConvAlgo\(\)", add_str, re.DOTALL)
if not sig_match:
continue
sig_str = sig_match.group(1)
# Extract algorithm part (between ConvAlgo() and arch string)
algo_match = re.search(
r'ConvAlgo\(\)(.*?),\s*"(\w+)"\s*\)', add_str, re.DOTALL
)
if not algo_match:
continue
algo_str = algo_match.group(1)
arch = algo_match.group(2)
# Parse signature
dtype = "fp16"
dtype_match = re.search(r'\.dtype\s*\(\s*"(\w+)"', sig_str)
if dtype_match:
dtype = dtype_match.group(1)
layout = "nhwgc"
layout_match = re.search(r'\.layout\s*\(\s*"(\w+)"', sig_str)
if layout_match:
layout = layout_match.group(1)
conv_type = "forward"
conv_type_match = re.search(r'\.conv_type\s*\(\s*"(\w+)"', sig_str)
if conv_type_match:
conv_type = conv_type_match.group(1)
num_dims = 2
dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str)
if dims_match:
num_dims = int(dims_match.group(1))
# Parse algorithm
tile_k, tile_c = 128, 128
tile_match = re.search(
r"\.tile\s*\(\s*\d+\s*,\s*(\d+)\s*,\s*(\d+)", algo_str
)
if tile_match:
tile_k = int(tile_match.group(1))
tile_c = int(tile_match.group(2))
wave_m, wave_n, wave_k = 2, 2, 1
wave_match = re.search(
r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str
)
if wave_match:
wave_m = int(wave_match.group(1))
wave_n = int(wave_match.group(2))
wave_k = int(wave_match.group(3) or 1)
warp_m, warp_n, warp_k = 32, 32, 16
warp_match = re.search(
r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str
)
if warp_match:
warp_m = int(warp_match.group(1))
warp_n = int(warp_match.group(2))
warp_k = int(warp_match.group(3) or 16)
pipeline = "compv4"
pipeline_match = re.search(r'\.pipeline\s*\(\s*"(\w+)"', algo_str)
if pipeline_match:
pipeline = pipeline_match.group(1)
scheduler = "intrawave"
scheduler_match = re.search(r'\.scheduler\s*\(\s*"(\w+)"', algo_str)
if scheduler_match:
scheduler = scheduler_match.group(1)
# Parse additional parameters
vector_a, vector_b, vector_c = 4, 8, 8
vector_match = re.search(
r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", algo_str
)
if vector_match:
vector_a = int(vector_match.group(1))
vector_b = int(vector_match.group(2))
vector_c = int(vector_match.group(3))
block_per_cu = 1
block_per_cu_match = re.search(r"\.block_per_cu\s*\(\s*(\d+)", algo_str)
if block_per_cu_match:
block_per_cu = int(block_per_cu_match.group(1))
memory_op = "set"
memory_op_match = re.search(r'\.memory_op\s*\(\s*"(\w+)"', algo_str)
if memory_op_match:
memory_op = memory_op_match.group(1)
epilogue = "cshuffle"
epilogue_match = re.search(r'\.epilogue\s*\(\s*"(\w+)"', algo_str)
if epilogue_match:
epilogue = epilogue_match.group(1)
# Parse num_wave_groups (for V5 pipeline)
num_wave_groups = 1
nwg_match = re.search(r"\.num_wave_groups\s*\(\s*(\d+)", algo_str)
if nwg_match:
num_wave_groups = int(nwg_match.group(1))
# Parse num_groups_to_merge (for merged group grouped convolution)
num_groups_to_merge = 1
ngm_match = re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)", algo_str)
if ngm_match:
num_groups_to_merge = int(ngm_match.group(1))
# Parse double_smem_buffer (for V4 pipeline)
double_smem_buffer = False
dsb_match = re.search(
r"\.double_smem_buffer\s*\(\s*(true|false)", algo_str, re.I
)
if dsb_match:
double_smem_buffer = dsb_match.group(1).lower() == "true"
# Parse padding flags
pad_m, pad_n, pad_k = True, True, True
padding_match = re.search(
r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
algo_str,
re.I,
)
if padding_match:
pad_m = padding_match.group(1).lower() == "true"
pad_n = padding_match.group(2).lower() == "true"
pad_k = padding_match.group(3).lower() == "true"
declarations.append(
{
"set": set_name,
"dtype": dtype,
"layout": layout,
"conv_type": conv_type,
"tile_k": tile_k,
"tile_c": tile_c,
"num_dims": num_dims,
"pipeline": pipeline,
"scheduler": scheduler,
"wave_m": wave_m,
"wave_n": wave_n,
"wave_k": wave_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"vector_a": vector_a,
"vector_b": vector_b,
"vector_c": vector_c,
"block_per_cu": block_per_cu,
"memory_op": memory_op,
"epilogue": epilogue,
"num_wave_groups": num_wave_groups,
"num_groups_to_merge": num_groups_to_merge,
"double_smem_buffer": double_smem_buffer,
"pad_m": pad_m,
"pad_n": pad_n,
"pad_k": pad_k,
"arch": arch,
}
)
return declarations
# =============================================================================
# VALIDATION AND AUTO-CORRECTION
# =============================================================================
def is_grouped_conv_wildcard_declaration(decl: dict) -> bool:
"""Check if a declaration uses wildcards (-1 or '*')."""
wildcard_fields = ["wave_m", "wave_n", "warp_m", "warp_n", "pipeline", "scheduler"]
for field in wildcard_fields:
val = decl.get(field)
if val == -1 or val == "*":
return True
return False
def validate_grouped_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
"""Validate a grouped conv kernel configuration against known supported combinations.
Returns: (is_valid, error_message)
"""
# Skip validation for wildcards - expansion will filter invalid combos
if is_grouped_conv_wildcard_declaration(decl):
return (True, None)
arch_data = get_arch_filter_data()
pipeline = decl.get("pipeline", "compv4")
scheduler = decl.get("scheduler", "intrawave")
dtype = decl.get("dtype", "fp16")
wave_m = decl.get("wave_m", 2)
wave_n = decl.get("wave_n", 2)
wave_k = decl.get("wave_k", 1)
warp_m = decl.get("warp_m", 32)
warp_n = decl.get("warp_n", 32)
warp_k = decl.get("warp_k", 16)
errors = []
# Check trait combination (pipeline, epilogue, scheduler)
combo = (pipeline, "cshuffle", scheduler)
if combo in arch_data["trait_unsupported"]:
errors.append(
f"Unsupported trait combination: pipeline={pipeline}, scheduler={scheduler}\n"
f" Valid schedulers for {pipeline}: intrawave"
)
# Check wave configuration for this arch
warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
wave_cfg = [wave_m, wave_n, wave_k]
if wave_cfg not in warp_combos:
valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos)
errors.append(
f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n"
f" Valid wave configs: {valid_str}"
)
# Check warp tile configuration for this arch and dtype
acc_dtype = "int32" if dtype == "int8" else "fp32"
dtype_key = f"{dtype}_{dtype}_{acc_dtype}"
warp_tile_combos = (
arch_data["warp_tile_combos"]
.get(arch, {})
.get(dtype_key, [[32, 32, 16], [16, 16, 16], [16, 16, 32]])
)
warp_cfg = [warp_m, warp_n, warp_k]
if warp_cfg not in warp_tile_combos:
valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5])
errors.append(
f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n"
f" Valid warp tiles: {valid_str}"
)
# Check arch is supported
if arch not in arch_data["supported_archs"]:
errors.append(
f"Unsupported architecture: {arch}\n"
f" Supported: {', '.join(arch_data['supported_archs'])}"
)
if errors:
return (False, "\n".join(errors))
return (True, None)
def expand_grouped_conv_declaration_with_arch_filter(
decl: dict, arch: str = "gfx942"
) -> list:
"""Expand a grouped conv declaration with wildcards into valid configurations.
Wildcards:
- wave_m/wave_n = -1: Try all valid wave configs for this arch
- warp_m/warp_n = -1: Try all valid warp tiles for this arch/dtype
- pipeline/scheduler = "*": Try all valid combinations
Returns a list of fully-specified declarations.
"""
arch_data = get_arch_filter_data()
dtype = decl.get("dtype", "fp16")
# Get valid combinations for this arch
valid_wave_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]])
acc_dtype = "int32" if dtype == "int8" else "fp32"
dtype_key = f"{dtype}_{dtype}_{acc_dtype}"
valid_warp_tiles = (
arch_data["warp_tile_combos"]
.get(arch, {})
.get(dtype_key, [[32, 32, 16], [16, 16, 16]])
)
# Valid pipelines and schedulers
valid_pipelines = ["compv3", "compv4"]
valid_schedulers = ["intrawave"] # interwave often unsupported
# Determine which fields need expansion
expand_wave = decl.get("wave_m", 2) == -1 or decl.get("wave_n", 2) == -1
expand_warp = decl.get("warp_m", 32) == -1 or decl.get("warp_n", 32) == -1
expand_pipeline = decl.get("pipeline", "compv4") == "*"
expand_scheduler = decl.get("scheduler", "intrawave") == "*"
# Build combinations
wave_options = (
valid_wave_combos
if expand_wave
else [[decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]]
)
warp_options = (
valid_warp_tiles
if expand_warp
else [[decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]]
)
pipeline_options = (
valid_pipelines if expand_pipeline else [decl.get("pipeline", "compv4")]
)
scheduler_options = (
valid_schedulers if expand_scheduler else [decl.get("scheduler", "intrawave")]
)
expanded = []
for wave in wave_options:
for warp in warp_options:
for pipeline in pipeline_options:
for scheduler in scheduler_options:
# Skip known invalid combinations
if (pipeline, "cshuffle", scheduler) in arch_data[
"trait_unsupported"
]:
continue
new_decl = decl.copy()
new_decl["wave_m"] = wave[0]
new_decl["wave_n"] = wave[1]
new_decl["wave_k"] = wave[2]
new_decl["warp_m"] = warp[0]
new_decl["warp_n"] = warp[1]
new_decl["warp_k"] = warp[2]
new_decl["pipeline"] = pipeline
new_decl["scheduler"] = scheduler
expanded.append(new_decl)
# If no valid expansions, return original (will fail validation later)
if not expanded:
return [decl]
# Return only the first valid config to keep the generated kernel count small
return expanded[:1]
def validate_and_expand_grouped_conv_declarations(
declarations: list, arch: str, verbose: bool = False
) -> list:
"""Validate declarations and auto-correct invalid ones via wildcard expansion."""
print(f"\n Validating against {arch} arch filter...")
wildcard_count = 0
invalid_count = 0
auto_corrections = []
for decl in declarations:
decl_arch = decl.get("arch", arch)
decl_name = (
f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}"
)
# Check for wildcards
if is_grouped_conv_wildcard_declaration(decl):
wildcard_count += 1
continue
is_valid, error_msg = validate_grouped_conv_kernel_config(decl, decl_arch)
if not is_valid:
print(f"\n WARNING Invalid grouped conv configuration: {decl_name}")
# Parse the error and show specific auto-corrections
corrections = []
original_values = {}
if "wave configuration" in error_msg.lower():
original_values["wave"] = (
f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]"
)
decl["wave_m"] = -1
decl["wave_n"] = -1
corrections.append(
f"wave: {original_values['wave']} -> [wildcard expansion]"
)
if "warp tile" in error_msg.lower():
original_values["warp"] = (
f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]"
)
decl["warp_m"] = -1
decl["warp_n"] = -1
corrections.append(
f"warp_tile: {original_values['warp']} -> [wildcard expansion]"
)
if "trait combination" in error_msg.lower():
original_values["pipeline"] = decl.get("pipeline", "compv4")
original_values["scheduler"] = decl.get("scheduler", "intrawave")
decl["pipeline"] = "*"
decl["scheduler"] = "*"
corrections.append(
f"pipeline: {original_values['pipeline']} -> [wildcard expansion]"
)
corrections.append(
f"scheduler: {original_values['scheduler']} -> [wildcard expansion]"
)
# Print the auto-corrections
print(" AUTO-CORRECTION:")
for corr in corrections:
print(f" - {corr}")
auto_corrections.append((decl_name, corrections))
invalid_count += 1
wildcard_count += 1
if invalid_count > 0:
print(
f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
)
if wildcard_count > 0:
print(
f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
)
else:
print(f" OK All {len(declarations)} configurations valid")
# Expand wildcards
print("\n Expanding wildcards to valid configurations...")
expanded_declarations = []
for decl in declarations:
decl_arch = decl.get("arch", arch)
decl_name = (
f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}"
)
expanded = expand_grouped_conv_declaration_with_arch_filter(decl, decl_arch)
expanded_declarations.extend(expanded)
if len(expanded) > 1:
print(
f" {decl_name}: expanded to {len(expanded)} valid configurations"
)
for exp in expanded[:3]:
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(
f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}"
)
if len(expanded) > 3:
print(f" ... and {len(expanded) - 3} more")
elif is_grouped_conv_wildcard_declaration(decl) and len(expanded) == 1:
exp = expanded[0]
wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}")
if len(expanded_declarations) != len(declarations):
print(
f"\n Total: {len(declarations)} declarations -> {len(expanded_declarations)} configurations"
)
return expanded_declarations
def _generate_single_grouped_conv_kernel(args: tuple) -> tuple:
"""Generate one grouped conv kernel (picklable for ProcessPoolExecutor).
Args: (decl, output_dir_str, gpu_target)
Returns: (idx, filepath_str or None, error_str or None)
"""
decl, output_dir_str, gpu_target = args
output_dir = Path(output_dir_str)
idx = decl.get("_idx", 0)
try:
from codegen_common import TileConfig
from unified_grouped_conv_codegen import (
GroupedConvKernelConfig,
GroupedConvTraitConfig,
GroupedConvVariant,
UnifiedGroupedConvCodegen,
)
# Map conv_type to variant
variant = GroupedConvVariant.FORWARD
if decl["conv_type"] == "bwd_data":
variant = GroupedConvVariant.BACKWARD_DATA
elif decl["conv_type"] == "bwd_weight":
variant = GroupedConvVariant.BACKWARD_WEIGHT
pipeline = decl.get("pipeline", "compv4")
adj_tile_k = 64 * 2 if pipeline == "compv4" else 64
# Create tile config (tile_m=tile_k, tile_n=tile_c for conv GEMM view)
tile = TileConfig(
tile_m=decl["tile_k"],
tile_n=decl["tile_c"],
tile_k=adj_tile_k,
warp_m=decl["wave_m"],
warp_n=decl["wave_n"],
warp_k=decl.get("wave_k", 1),
warp_tile_m=decl["warp_m"],
warp_tile_n=decl["warp_n"],
warp_tile_k=decl["warp_k"],
)
trait = GroupedConvTraitConfig(
pipeline=pipeline,
scheduler=decl["scheduler"],
epilogue=decl.get("epilogue", "cshuffle"),
double_smem_buffer=decl.get("double_smem_buffer", False),
pad_m=decl.get("pad_m", True),
pad_n=decl.get("pad_n", True),
pad_k=decl.get("pad_k", True),
num_groups_to_merge=decl.get("num_groups_to_merge", 1),
)
config = GroupedConvKernelConfig(
tile=tile,
trait=trait,
variant=variant,
ndim_spatial=decl["num_dims"],
arch=decl.get("arch", gpu_target),
vector_size_a=decl.get("vector_a", 4),
vector_size_b=decl.get("vector_b", 8),
vector_size_c=decl.get("vector_c", 8),
block_per_cu=decl.get("block_per_cu", 1),
num_wave_groups=decl.get("num_wave_groups", 1),
num_groups_to_merge=decl.get("num_groups_to_merge", 1),
double_smem_buffer=decl.get("double_smem_buffer", False),
)
codegen = UnifiedGroupedConvCodegen(output_dir, gpu_target=gpu_target)
kernel_path, _ = codegen.generate_kernel(config, decl["dtype"], variant)
return (idx, str(kernel_path), None)
except Exception as e:
return (idx, None, str(e))
def generate_grouped_conv_kernels(
declarations: list,
output_dir: Path,
gpu_target: str = "gfx942",
max_workers: Optional[int] = None,
) -> list:
"""Generate grouped convolution kernels using unified_grouped_conv_codegen.
Uses ProcessPoolExecutor for parallel kernel generation.
"""
output_dir.mkdir(parents=True, exist_ok=True)
# Prepare work items (add _idx for ordering)
work_items = []
for idx, decl in enumerate(declarations):
decl_copy = decl.copy()
decl_copy["_idx"] = idx
work_items.append((decl_copy, str(output_dir), gpu_target))
max_workers = max_workers or min(len(work_items), os.cpu_count() or 4)
generated = []
failed = []
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_generate_single_grouped_conv_kernel, w): w[0]["_idx"]
for w in work_items
}
for future in as_completed(futures):
idx, path, err = future.result()
if path:
generated.append(Path(path))
print_info(f" Generated: {Path(path).name}")
else:
failed.append((idx, err))
print_error(f" Failed kernel {idx + 1}: {err}")
if failed:
for idx, err in failed[:3]:
print_error(f" Kernel {idx + 1}: {err[:200]}")
if len(failed) > 3:
print_error(f" ... and {len(failed) - 3} more failures")
return generated
def compile_grouped_conv_example(
source_file: Path,
output_bin: Path,
kernel_headers: list,
hipcc: str,
gpu_target: str,
) -> bool:
"""Compile the C++ example with generated kernels."""
kernel_dir = get_generated_kernels_dir()
ck_root = get_ck_root()
dispatcher_dir = get_dispatcher_root()
includes = [
f"-I{ck_root / 'include'}",
f"-I{dispatcher_dir / 'include'}",
f"-I{kernel_dir}",
]
# Build include flags for generated kernels
kernel_includes = []
for header in kernel_headers:
kernel_includes.extend(["-include", str(header)])
# Add define to indicate kernels are available
defines = ["-DGROUPED_CONV_KERNEL_AVAILABLE=1"]
cmd = [
hipcc,
"-std=c++20",
"-O2",
f"--offload-arch={gpu_target}",
*includes,
*defines,
*kernel_includes,
"-o",
str(output_bin),
str(source_file),
]
print_info(f" Compiling: {source_file.name}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
if result.stderr:
lines = result.stderr.split("\n")
errors = [line for line in lines if "error:" in line.lower()][:5]
for err_line in errors:
print_error(f" {err_line}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Build C++ grouped convolution example with self-contained kernel generation"
)
parser.add_argument("source", help="Source file (.cpp)")
parser.add_argument("--output", "-o", help="Output binary name")
parser.add_argument("--gpu-target", default="gfx942", help="GPU target")
parser.add_argument(
"--no-compile", action="store_true", help="Only generate kernels, don't compile"
)
parser.add_argument("--verbose", "-v", action="store_true")
parser.add_argument(
"--jobs",
"-j",
type=int,
default=None,
help="Parallel jobs for kernel generation (default: cpu_count)",
)
args = parser.parse_args()
# Resolve source file
source_file = Path(args.source)
if not source_file.is_absolute():
candidates = [
get_dispatcher_root() / args.source,
Path.cwd() / args.source,
]
for c in candidates:
if c.exists():
source_file = c
break
if not source_file.exists():
print_error(f"Source file not found: {source_file}")
return 1
build_dir = get_build_dir()
kernel_dir = get_generated_kernels_dir()
output_name = args.output or source_file.stem
output_bin = build_dir / output_name
print_success("=== Grouped Conv Example Builder (Self-Contained) ===")
# Phase 1: Extract declarations
print_phase(1, "Scanning for DECL_GROUPED_CONV_KERNEL_SET...")
declarations = extract_grouped_conv_declarations(source_file)
if not declarations:
print_error(" No DECL_GROUPED_CONV_KERNEL_SET declarations found!")
return 1
print(f" Found {len(declarations)} kernel declaration(s):")
for decl in declarations:
name = f"{decl['dtype']}_{decl['conv_type']}_{decl['num_dims']}d_{decl['tile_k']}x{decl['tile_c']}"
print(f" [{decl['set']}] {name}")
# Phase 2: Validate and expand
print_phase(2, "Validating and expanding declarations...")
declarations = validate_and_expand_grouped_conv_declarations(
declarations, args.gpu_target, args.verbose
)
print()
# Phase 3: Generate kernels
print_phase(3, "Generating kernels...")
generated = generate_grouped_conv_kernels(
declarations, kernel_dir, args.gpu_target, max_workers=args.jobs
)
if not generated:
print_error(" No kernels generated!")
return 1
print(f" Generated {len(generated)} kernel file(s)")
print()
# Phase 4: Compile (optional)
if args.no_compile:
print_info("Skipping compilation (--no-compile)")
print()
print_success("=== Kernel Generation Complete ===")
print(f"Kernels in: {kernel_dir}")
return 0
print_phase(4, "Compiling example...")
hipcc_path = find_hipcc()
if not hipcc_path:
print_error(" hipcc not found. Install ROCm or set HIPCC env var.")
print(" To compile manually:")
ck_root = get_dispatcher_root().parent
print(
f" hipcc -std=c++20 -O2 -I{ck_root / 'include'} -I{get_dispatcher_root() / 'include'} \\"
)
print(f" -I{kernel_dir} \\")
for h in generated[:1]:
print(f" -include {h} \\")
print(" -DGROUPED_CONV_KERNEL_AVAILABLE=1 \\")
print(f" --offload-arch={args.gpu_target} \\")
print(f" {source_file} -o {output_bin}")
return 1
build_dir.mkdir(parents=True, exist_ok=True)
if not compile_grouped_conv_example(
source_file, output_bin, generated, hipcc_path, args.gpu_target
):
print_error(" Compilation failed!")
return 1
print_success(f" Output: {output_bin}")
print()
print_success("=== Build Complete ===")
print()
print("Run with:")
print(f" {output_bin}")
return 0
if __name__ == "__main__":
sys.exit(main())
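`extract_grouped_conv_declarations` above relies on a balanced-parenthesis scan because a regex alone cannot match the nested parentheses inside `.add(ConvSig()...)` chains. A stripped-down, runnable sketch of that scan (`macro_body` and the sample source are illustrative helpers, not part of the dispatcher):

```python
import re

def macro_body(text, macro="DECL_GROUPED_CONV_KERNEL_SET"):
    """Return (set_name, body) for the first macro invocation in text."""
    m = re.search(macro + r"\s*\(\s*(\w+)\s*,", text)
    if not m:
        return None, None
    depth, start = 1, m.end()  # already inside the macro's opening paren
    for i, c in enumerate(text[start:]):
        if c == "(":
            depth += 1
        elif c == ")":
            depth -= 1
            if depth == 0:
                # Body runs up to (not including) the macro's closing paren.
                return m.group(1), text[start:start + i]
    return m.group(1), None  # unbalanced input

src = 'DECL_GROUPED_CONV_KERNEL_SET(my_set, .add("fp16", "nhwgc", "forward", 128, 128));'
name, body = macro_body(src)
print(name)             # my_set
print(".add(" in body)  # True
```

The real extractor applies the same walk, then runs the simple- and full-form `.add(...)` regexes over the isolated body.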


@@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str:
def parse_conv_declarations(content: str) -> List[Dict]:
"""Parse DECL_CONV_KERNEL_SET declarations with all parameters."""
"""Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters."""
kernels = []
for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content):
body = extract_balanced_parens(content, match.end() - 1)
if not body:
continue
@@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str:
n = len(content)
# Patterns that indicate a string is problematic and should be stripped
problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("]
while i < n:
# Check for raw string literal: R"delimiter(...)delimiter"
@@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]:
content = source_path.read_text()
content = strip_cpp_strings_and_comments(content)
if "DECL_CONV_KERNEL_SET" in content:
if "DECL_GROUPED_CONV_KERNEL_SET" in content:
return "conv", parse_conv_declarations(content)
elif "DECL_KERNEL_SET" in content:
return "gemm", parse_gemm_declarations(content)
@@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str:
def generate_conv_registration(
kernel_headers: List[Path], example_name: str, kernels: List[Dict]
) -> str:
"""Generate Conv kernel registration code for the dispatcher registry."""
"""Generate Conv kernel registration code for the dispatcher registry.
Creates real GroupedConvKernelInstance entries backed by the generated
launcher's launch() method via the conv backend RunFn factories.
"""
if not kernel_headers:
return " // No kernels to register"
lines = []
for i, h in enumerate(kernel_headers):
kname = h.stem
ns = f"ns_{kname}"
launcher = f"{ns}::{kname}_Launcher"
# Determine direction and ndim from the kernel header name
if "_fwd_" in kname:
direction = "Forward"
run_fn_factory = "make_conv_fwd_run_fn"
elif "_bwd_data_" in kname or "_bwdd_" in kname:
direction = "BackwardData"
run_fn_factory = "make_conv_bwd_data_run_fn"
elif "_bwd_weight_" in kname or "_bwdw_" in kname:
direction = "BackwardWeight"
run_fn_factory = "make_conv_bwd_weight_run_fn"
else:
direction = "Forward"
run_fn_factory = "make_conv_fwd_run_fn"
ndim = 3 if "_3d_" in kname else 2
# Parse dtype from name (e.g. grouped_conv_fwd_fp16_...)
dtype = "fp16"
for dt in ["fp16", "bf16", "fp32"]:
if f"_{dt}_" in kname:
dtype = dt
break
# Parse tile, wave, warp from name.
# Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_...
import re as _re
tile_m, tile_n, tile_k = 1, 128, 128
wave_m, wave_n, wave_k = 2, 2, 1
warp_m, warp_n, warp_k = 32, 32, 16
triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname)
if len(triplets) >= 1:
tile_m, tile_n, tile_k = (
int(triplets[0][0]),
int(triplets[0][1]),
int(triplets[0][2]),
)
if len(triplets) >= 2:
wave_m, wave_n, wave_k = (
int(triplets[1][0]),
int(triplets[1][1]),
int(triplets[1][2]),
)
if len(triplets) >= 3:
warp_m, warp_n, warp_k = (
int(triplets[2][0]),
int(triplets[2][1]),
int(triplets[2][2]),
)
pipeline = "compv4" if "compv4" in kname else "compv3"
scheduler = "interwave" if "interwave" in kname else "intrawave"
epilogue = "cshuffle" if "cshuffle" in kname else "default"
# ConvConfigBase defaults
vec_a, vec_b, vec_c = 4, 8, 8
block_per_cu = 1
num_wave_groups = 1
num_groups_to_merge = 1
lines.append(f" // Kernel {i + 1}: {kname}")
lines.append(" {")
lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};")
lines.append(f' key_{i}.dtype_in = "{dtype}";')
lines.append(f' key_{i}.dtype_wei = "{dtype}";')
lines.append(f' key_{i}.dtype_out = "{dtype}";')
lines.append(f' key_{i}.layout = "nhwgc";')
lines.append(f" key_{i}.ndim_spatial = {ndim};")
lines.append(
f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};"
)
lines.append(f" key_{i}.tile_m = {tile_m};")
lines.append(f" key_{i}.tile_n = {tile_n};")
lines.append(f" key_{i}.tile_k = {tile_k};")
lines.append(f" key_{i}.wave_m = {wave_m};")
lines.append(f" key_{i}.wave_n = {wave_n};")
lines.append(f" key_{i}.wave_k = {wave_k};")
lines.append(f" key_{i}.warp_m = {warp_m};")
lines.append(f" key_{i}.warp_n = {warp_n};")
lines.append(f" key_{i}.warp_k = {warp_k};")
lines.append(f' key_{i}.pipeline = "{pipeline}";')
lines.append(f' key_{i}.scheduler = "{scheduler}";')
lines.append(f' key_{i}.epilogue = "{epilogue}";')
lines.append(f" key_{i}.vector_size_a = {vec_a};")
lines.append(f" key_{i}.vector_size_b = {vec_b};")
lines.append(f" key_{i}.vector_size_c = {vec_c};")
lines.append(f" key_{i}.block_per_cu = {block_per_cu};")
lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};")
lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};")
lines.append(f" key_{i}.arch = arch;")
lines.append(
f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();"
)
lines.append(
f' auto inst_{i} = std::make_shared<ck_tile::dispatcher::GroupedConvKernelInstance>(key_{i}, "{kname}", std::move(run_fn_{i}));'
)
lines.append(f" registry.register_kernel(key_{i}, inst_{i});")
lines.append(" }")
return "\n".join(lines)
def _build_conv_codegen_cmd(
idx: int, k: Dict, codegen_dir: Path, output_dir: Path
) -> Tuple[int, List[str], str]:
"""Build the command for a single conv kernel codegen invocation."""
variant_map = {
"forward": "forward",
"bwd_data": "bwd_data",
@@ -997,93 +1095,130 @@ def generate_conv_kernels(
"bwd_weight": "bwd_weight",
"backward_weight": "bwd_weight",
}
variant = variant_map.get(k.get("conv_type", "forward"), "forward")
cmd = [
sys.executable,
str(codegen_dir / "unified_grouped_conv_codegen.py"),
"--datatype",
k.get("dtype", "fp16"),
"--variant",
variant,
"--ndim",
str(k.get("ndim", 2)),
"--output",
str(output_dir),
]
if k.get("tile_m"):
cmd.extend(["--tile-m", str(k["tile_m"])])
if k.get("tile_n"):
cmd.extend(["--tile-n", str(k["tile_n"])])
if k.get("warp_m"):
cmd.extend(["--warp-m", str(k["warp_m"])])
if k.get("warp_n"):
cmd.extend(["--warp-n", str(k["warp_n"])])
if k.get("warp_k"):
cmd.extend(["--warp-k", str(k["warp_k"])])
if k.get("warp_tile_m"):
cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])])
if k.get("warp_tile_n"):
cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])])
if k.get("warp_tile_k"):
cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])])
if k.get("pipeline"):
cmd.extend(["--pipeline", k["pipeline"]])
if k.get("scheduler"):
cmd.extend(["--scheduler", k["scheduler"]])
if k.get("epilogue"):
cmd.extend(["--epilogue", k["epilogue"]])
if k.get("vector_a"):
cmd.extend(["--vector-a", str(k["vector_a"])])
if k.get("vector_b"):
cmd.extend(["--vector-b", str(k["vector_b"])])
if k.get("vector_c"):
cmd.extend(["--vector-c", str(k["vector_c"])])
if k.get("block_per_cu"):
cmd.extend(["--block-per-cu", str(k["block_per_cu"])])
if k.get("num_wave_groups"):
cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])])
if k.get("num_groups_to_merge"):
cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])])
if k.get("double_smem_buffer") is not None:
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()])
if k.get("tile_k"):
cmd.extend(["--tile-k", str(k["tile_k"])])
return (idx, cmd, str(codegen_dir))
def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_conv_kernels(
kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
"""Generate Conv kernels for ALL declarations using unified codegen.
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
for significantly faster generation when multiple conv kernels are declared.
"""
if not kernels:
return False
work_items = [
_build_conv_codegen_cmd(idx, k, codegen_dir, output_dir)
for idx, k in enumerate(kernels)
]
success_count = 0
max_workers = min(len(work_items), os.cpu_count() or 4)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items}
for future in as_completed(futures):
idx, ok, err = future.result()
if ok:
success_count += 1
else:
print(f" Codegen error for kernel {idx + 1}: {err}")
return success_count > 0
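The fan-out used by generate_conv_kernels (submit prebuilt work items, collect with as_completed, count successes) can be sketched in isolation. ThreadPoolExecutor is used here only so the sketch runs self-contained; the script itself uses ProcessPoolExecutor with picklable module-level workers:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_one(item):
    # Stand-in for _run_conv_codegen: returns (idx, ok, err).
    idx, payload = item
    ok = payload % 2 == 0
    return (idx, ok, "" if ok else "odd payload")


work_items = [(i, i) for i in range(6)]
success_count = 0
max_workers = min(len(work_items), 4)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = {executor.submit(run_one, w): w[0] for w in work_items}
    for future in as_completed(futures):
        idx, ok, err = future.result()
        if ok:
            success_count += 1
# success_count == 3 (items 0, 2, 4)
```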
def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]:
"""Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor)."""
idx, cmd, cwd = args
result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
if result.returncode != 0:
return (idx, False, result.stderr[:300])
return (idx, True, "")
def generate_gemm_kernels(
kernels: List[Dict], output_dir: Path, codegen_dir: Path
) -> bool:
"""Generate GEMM kernels for ALL declarations using unified codegen."""
"""Generate GEMM kernels for ALL declarations using unified codegen.
Launches all codegen subprocesses in parallel via ProcessPoolExecutor
for significantly faster generation when multiple kernels are declared.
"""
import json
if not kernels:
return False
# Build all commands upfront
work_items = []
for idx, k in enumerate(kernels):
variant = "multi_d" if k.get("elementwise_op") else "standard"
# Build tile config JSON for this specific kernel
tile_config = {
"tile_m": [k.get("tile_m", 128)],
"tile_n": [k.get("tile_n", 128)],
@@ -1125,13 +1260,20 @@ def generate_gemm_kernels(
config_json,
]
work_items.append((idx, cmd, str(codegen_dir)))
# Run all codegen subprocesses in parallel
success_count = 0
max_workers = min(len(work_items), os.cpu_count() or 4)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items}
for future in as_completed(futures):
idx, ok, err = future.result()
if ok:
success_count += 1
else:
print(f" Codegen error for kernel {idx + 1}: {err}")
return success_count > 0
@@ -1229,15 +1371,17 @@ def main():
if example_type == "gemm":
kernel_headers = list(args.output_dir.glob("gemm_*.hpp"))
else:
prefix_map = {
"forward": "grouped_conv_fwd",
"bwd_data": "grouped_conv_bwd_data",
"bwd_weight": "grouped_conv_bwd_weight",
}
# Collect headers from ALL variants present in declarations
variants_used = set(k.get("conv_type", "forward") for k in kernels)
kernel_headers = []
for variant in variants_used:
prefix = prefix_map.get(variant, "grouped_conv_fwd")
kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp"))
if not kernel_headers:
print(f"[{target_name}] No kernel headers generated!")
@@ -1347,29 +1491,39 @@ def main():
)
if has_bwd_data:
bwd_data_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwd_data_"
)
if not bwd_data_kernel:
bwd_data_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwdd_"
)
if bwd_data_kernel:
bwd_data_ns = f"ns_{bwd_data_kernel.stem}"
launcher_aliases.append(
f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
)
if not has_fwd:
launcher_aliases.append(
f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;"
)
if has_bwd_weight:
bwd_weight_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwd_weight_"
)
if not bwd_weight_kernel:
bwd_weight_kernel = find_kernel_by_dtype_type(
kernel_headers, "fp16", "_bwdw_"
)
if bwd_weight_kernel:
bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}"
launcher_aliases.append(
f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
)
if not has_fwd and not has_bwd_data:
launcher_aliases.append(
f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;"
)
launcher_section = "\n".join(launcher_aliases)
@@ -1382,14 +1536,16 @@ def main():
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/kernel_instance.hpp"
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "ck_tile/dispatcher/grouped_conv_registry.hpp"
#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp"
namespace generated {{
// Kernel launchers for direct use
{launcher_section}
// Registration function (takes GroupedConvRegistry for conv kernels)
inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{
{register_body}
}}
@@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri
"""
header_path.write_text(header_content)
print(f"[{target_name}] {len(obj_files)} kernels compiled")
print(f"[{target_name}] OK {len(obj_files)} kernels compiled")
return 0

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Generate the conv_python_dispatch.hpp header for the Python conv library.
Reads the include_all headers to find available kernels and creates dispatch
aliases for 2D/3D x fwd/bwd_data/bwd_weight.
"""
import argparse
import re
from pathlib import Path
def find_3d_launcher(include_all_path: Path, variant_prefix: str) -> str:
"""Find first 3D launcher name from an include_all header."""
text = include_all_path.read_text()
pattern = rf"(grouped_conv_{variant_prefix}_\w+_3d_\w+)\.hpp"
match = re.search(pattern, text)
if match:
return match.group(1) + "_Launcher"
return ""
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--kernel-dir", required=True)
parser.add_argument("--output", required=True)
args = parser.parse_args()
kdir = Path(args.kernel_dir)
fwd_3d = find_3d_launcher(kdir / "include_all_grouped_conv_fwd_kernels.hpp", "fwd")
bwd_data_3d = find_3d_launcher(
kdir / "include_all_grouped_conv_bwd_data_kernels.hpp", "bwd_data"
)
bwd_weight_3d = find_3d_launcher(
kdir / "include_all_grouped_conv_bwd_weight_kernels.hpp", "bwd_weight"
)
lines = [
"// Auto-generated dispatch header for Python conv library",
"#pragma once",
"",
"// Forward kernels",
'#include "include_all_grouped_conv_fwd_kernels.hpp"',
"#define CONV_FWD_2D_AVAILABLE 1",
]
if fwd_3d:
lines += [
"#define CONV_FWD_3D_AVAILABLE 1",
f"using ConvFwd3dLauncher = {fwd_3d};",
]
lines += [
"",
"// Backward data kernels",
'#include "include_all_grouped_conv_bwd_data_kernels.hpp"',
"#define CONV_BWD_DATA_2D_AVAILABLE 1",
]
if bwd_data_3d:
lines += [
"#define CONV_BWD_DATA_3D_AVAILABLE 1",
f"using ConvBwdData3dLauncher = {bwd_data_3d};",
]
lines += [
"",
"// Backward weight kernels",
'#include "include_all_grouped_conv_bwd_weight_kernels.hpp"',
"#define CONV_BWD_WEIGHT_2D_AVAILABLE 1",
]
if bwd_weight_3d:
lines += [
"#define CONV_BWD_WEIGHT_3D_AVAILABLE 1",
f"using ConvBwdWeight3dLauncher = {bwd_weight_3d};",
]
# Kernel name table for Python introspection
names = []
if True: # fwd 2D always present
names.append('"fwd_2d"')
if fwd_3d:
names.append('"fwd_3d"')
if True: # bwd_data 2D
names.append('"bwd_data_2d"')
if bwd_data_3d:
names.append('"bwd_data_3d"')
if True: # bwd_weight 2D
names.append('"bwd_weight_2d"')
if bwd_weight_3d:
names.append('"bwd_weight_3d"')
lines += [
"",
"// Kernel inventory for Python",
f"static const char* CONV_KERNEL_NAMES[] = {{{', '.join(names)}}};",
f"static const int CONV_KERNEL_COUNT = {len(names)};",
"",
]
Path(args.output).write_text("\n".join(lines) + "\n")
print(f"Generated dispatch header: {args.output} ({len(names)} kernels)")
if __name__ == "__main__":
main()

View File

@@ -132,7 +132,7 @@ def main():
print(f"Linking failed: {result.stderr}")
return 1
print(f" Built: {lib_path}")
print(f"OK Built: {lib_path}")
return 0

View File

@@ -34,9 +34,9 @@ from compile_gemm_examples import ( # noqa: E402
validate_kernel_config,
expand_declaration_with_arch_filter,
)
from compile_grouped_conv_examples import ( # noqa: E402
validate_grouped_conv_kernel_config as validate_conv_kernel_config,
expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter,
)
@@ -316,7 +316,7 @@ def test_python_autocorrect(verbose=False):
if was_modified:
print(f" Modified: {len(corrections)} correction(s)")
for c in corrections:
print(f" {c}")
print(f" - {c}")
except Exception as e:
results["failed"] += 1
@@ -465,7 +465,7 @@ def run_stress_test(arch, num_samples, verbose):
}
expanded = expand_declaration_with_arch_filter(config, test_arch)
status = "" if expanded else ""
status = "OK" if expanded else "FAIL"
expected = test_arch in test["expected_archs"]
match = "OK" if (bool(expanded) == expected) else "MISMATCH"

View File

@@ -2,17 +2,18 @@
// SPDX-License-Identifier: MIT
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/dispatcher_error.hpp"
#include <sstream>
#include <iostream>
namespace ck_tile {
namespace dispatcher {
Dispatcher::Dispatcher(Registry* registry, const std::string& gfx_arch)
: registry_(registry ? registry : &Registry::instance()),
heuristic_(nullptr),
strategy_(SelectionStrategy::FirstFit),
gfx_arch_(gfx_arch)
{
}
@@ -61,7 +62,7 @@ float Dispatcher::run_fused(const void* a_ptr,
std::ostringstream oss;
oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N
<< " K=" << problem.K;
throw NoKernelFound(oss.str());
}
return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
@@ -78,7 +79,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id,
auto kernel = registry_->lookup(kernel_id);
if(!kernel)
{
throw std::runtime_error("Kernel not found: " + kernel_id);
throw NoKernelFound("Kernel not found: " + kernel_id);
}
if(!kernel->supports(problem))
@@ -86,7 +87,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id,
std::ostringstream oss;
oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M
<< " N=" << problem.N << " K=" << problem.K;
throw UnsupportedProblem(oss.str());
}
return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);

View File

@@ -5,39 +5,32 @@
#include "ck_tile/dispatcher/json_export.hpp"
#include "ck_tile/dispatcher/arch_filter.hpp"
#include <algorithm>
#include <fstream>
#include <iostream>
namespace ck_tile {
namespace dispatcher {
Registry::Registry() = default;
Registry::~Registry()
{
// Perform auto-export on destruction if enabled (regardless of export_on_every_registration
// setting)
if(auto_export_enabled_)
{
perform_auto_export();
}
}
Registry::Registry(Registry&& other) noexcept : Base(std::move(other))
{
// Base move constructor already locked+released other.mutex_.
// Re-acquire to safely read the remaining fields.
std::lock_guard<std::mutex> lock(other.mutex());
auto_export_enabled_ = other.auto_export_enabled_;
auto_export_filename_ = std::move(other.auto_export_filename_);
auto_export_include_statistics_ = other.auto_export_include_statistics_;
auto_export_on_every_registration_ = other.auto_export_on_every_registration_;
other.auto_export_enabled_ = false;
}
@@ -45,11 +38,7 @@ Registry& Registry::operator=(Registry&& other) noexcept
{
if(this != &other)
{
Base::operator=(std::move(other));
auto_export_enabled_ = other.auto_export_enabled_;
auto_export_filename_ = std::move(other.auto_export_filename_);
auto_export_include_statistics_ = other.auto_export_include_statistics_;
@@ -64,55 +53,27 @@ Registry& Registry::operator=(Registry&& other) noexcept
bool Registry::register_kernel(KernelInstancePtr instance, Priority priority)
{
if(!instance)
{
return false;
}
if(Base::register_kernel(instance->get_name(), instance, priority))
{
if(auto_export_enabled_ && auto_export_on_every_registration_)
{
perform_auto_export();
}
return true;
}
return false;
}
KernelInstancePtr Registry::lookup(const std::string& identifier) const
{
std::lock_guard<std::mutex> lock(mutex());
auto it = entries().find(identifier);
if(it != entries().end())
{
return it->second.instance;
}
return nullptr;
}
@@ -121,75 +82,23 @@ KernelInstancePtr Registry::lookup(const KernelKey& key) const
return lookup(key.encode_identifier());
}
std::vector<KernelInstancePtr> Registry::get_all() const { return Base::get_all_instances(); }
std::vector<KernelInstancePtr>
Registry::filter(std::function<bool(const KernelInstance&)> predicate) const
{
std::lock_guard<std::mutex> lock(mutex());
std::vector<KernelInstancePtr> result;
for(const auto& [name, entry] : entries())
{
if(predicate(*(entry.instance)))
{
result.push_back(entry.instance);
}
}
return result;
}
std::string Registry::export_json(bool include_statistics) const
{
return export_registry_json(*this, include_statistics);
@@ -204,7 +113,7 @@ void Registry::enable_auto_export(const std::string& filename,
bool include_statistics,
bool export_on_every_registration)
{
std::lock_guard<std::mutex> lock(mutex());
auto_export_enabled_ = true;
auto_export_filename_ = filename;
auto_export_include_statistics_ = include_statistics;
@@ -213,13 +122,13 @@ void Registry::enable_auto_export(const std::string& filename,
void Registry::disable_auto_export()
{
std::lock_guard<std::mutex> lock(mutex());
auto_export_enabled_ = false;
}
bool Registry::is_auto_export_enabled() const
{
std::lock_guard<std::mutex> lock(mutex());
return auto_export_enabled_;
}
@@ -230,7 +139,7 @@ void Registry::perform_auto_export()
bool include_stats;
{
std::lock_guard<std::mutex> lock(mutex());
if(!auto_export_enabled_)
{
return;
@@ -243,31 +152,15 @@ void Registry::perform_auto_export()
export_json_to_file(filename, include_stats);
}
std::size_t Registry::filter_by_arch(const std::string& gpu_arch)
{
ArchFilter filter(gpu_arch);
std::vector<std::string> to_remove;
{
std::lock_guard<std::mutex> lock(mutex());
for(const auto& pair : entries())
{
if(!filter.is_valid(pair.second.instance->get_key()))
{
@@ -277,12 +170,18 @@ std::size_t Registry::filter_by_arch(const std::string& gpu_arch)
for(const auto& key : to_remove)
{
entries_mut().erase(key);
}
}
return to_remove.size();
}
Registry& Registry::instance()
{
static Registry global_registry;
return global_registry;
}
} // namespace dispatcher
} // namespace ck_tile

View File

@@ -217,6 +217,10 @@ endforeach()
# Standalone integration tests (with their own main())
set(STANDALONE_TESTS
test_minimal.cpp
test_grouped_conv_config.cpp
test_grouped_conv_problem.cpp
test_grouped_conv_kernel_decl.cpp
test_grouped_conv_registry.cpp
)
foreach(test_source ${STANDALONE_TESTS})

View File

@@ -42,10 +42,10 @@ from compile_gemm_examples import ( # noqa: E402
expand_declaration_with_arch_filter,
is_wildcard_declaration,
)
from compile_grouped_conv_examples import ( # noqa: E402
validate_grouped_conv_kernel_config as validate_conv_kernel_config,
expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter,
is_grouped_conv_wildcard_declaration as is_conv_wildcard_declaration,
)
from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for codegen/codegen_common.py -- shared infrastructure for GEMM and grouped conv codegen.
Phase 1a TDD: these tests are written BEFORE the implementation exists.
Run: python3 -m pytest tests/test_codegen_common.py -v
"""
import sys
import unittest
from pathlib import Path
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from codegen_common import ( # noqa: E402
TileConfig,
TraitConfigBase,
CommonTypeMappings,
generate_cpp_compilation_unit,
parallel_generate,
valid_wave_configs,
valid_warp_configs,
valid_trait_configs,
needs_wave_expansion,
needs_warp_expansion,
needs_pipeline_expansion,
)
class TestTileConfig(unittest.TestCase):
"""TileConfig dataclass tests."""
def test_valid_config(self):
tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
self.assertTrue(tc.is_valid())
def test_zero_tile_invalid(self):
tc = TileConfig(0, 128, 32, 2, 2, 1, 32, 32, 16)
self.assertFalse(tc.is_valid())
def test_non_divisible_invalid(self):
tc = TileConfig(127, 128, 32, 2, 2, 1, 32, 32, 16)
self.assertFalse(tc.is_valid())
def test_all_fields_accessible(self):
tc = TileConfig(256, 128, 64, 4, 1, 1, 32, 32, 16)
self.assertEqual(tc.tile_m, 256)
self.assertEqual(tc.tile_n, 128)
self.assertEqual(tc.tile_k, 64)
self.assertEqual(tc.warp_m, 4)
self.assertEqual(tc.warp_n, 1)
self.assertEqual(tc.warp_k, 1)
self.assertEqual(tc.warp_tile_m, 32)
self.assertEqual(tc.warp_tile_n, 32)
self.assertEqual(tc.warp_tile_k, 16)
def test_small_valid_config(self):
tc = TileConfig(16, 16, 16, 1, 1, 1, 16, 16, 16)
self.assertTrue(tc.is_valid())
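These TDD tests pin down TileConfig's behaviour before the implementation exists. A minimal sketch consistent with them (the real codegen_common implementation may impose further constraints) checks positivity and that each block-tile dimension divides evenly into warps times warp-tile:

```python
from dataclasses import dataclass


@dataclass
class TileConfigSketch:
    # Hypothetical stand-in for codegen_common.TileConfig.
    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    def is_valid(self) -> bool:
        if min(self.tile_m, self.tile_n, self.tile_k) <= 0:
            return False
        # Each block-tile dim must be evenly covered by warps x warp tile.
        return (
            self.tile_m % (self.warp_m * self.warp_tile_m) == 0
            and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
            and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
        )
```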
class TestTraitConfigBase(unittest.TestCase):
"""TraitConfigBase dataclass tests."""
def test_valid_intrawave(self):
tc = TraitConfigBase("compv3", "cshuffle", "intrawave", False, False, False)
self.assertTrue(tc.is_valid())
def test_invalid_interwave_compv3(self):
tc = TraitConfigBase("compv3", "cshuffle", "interwave", False, False, False)
self.assertFalse(tc.is_valid())
def test_invalid_interwave_compv4(self):
tc = TraitConfigBase("compv4", "cshuffle", "interwave", False, False, False)
self.assertFalse(tc.is_valid())
def test_valid_mem_interwave(self):
tc = TraitConfigBase("mem", "cshuffle", "interwave", False, False, False)
self.assertTrue(tc.is_valid())
def test_valid_mem_intrawave(self):
tc = TraitConfigBase("mem", "cshuffle", "intrawave", False, False, False)
self.assertTrue(tc.is_valid())
def test_padding_fields(self):
tc = TraitConfigBase("compv3", "cshuffle", "intrawave", True, True, True)
self.assertTrue(tc.pad_m)
self.assertTrue(tc.pad_n)
self.assertTrue(tc.pad_k)
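The trait-validity rule these tests encode is that compute pipelines cannot use interwave scheduling. A hypothetical `TraitConfigBase` sketch satisfying them:

```python
from dataclasses import dataclass


@dataclass
class TraitConfigBase:
    """Pipeline/epilogue/scheduler selection plus M/N/K padding flags."""
    pipeline: str   # "mem", "compv3", or "compv4"
    epilogue: str   # e.g. "cshuffle"
    scheduler: str  # "intrawave" or "interwave"
    pad_m: bool
    pad_n: bool
    pad_k: bool

    def is_valid(self) -> bool:
        # Interwave scheduling is only supported by the memory-bound
        # pipeline; compv3/compv4 require intrawave.
        if self.scheduler == "interwave" and self.pipeline in ("compv3", "compv4"):
            return False
        return True
```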
class TestCommonTypeMappings(unittest.TestCase):
"""CommonTypeMappings tests."""
def test_dtype_to_ck(self):
self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp16"], "fp16_t")
self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["bf16"], "bf16_t")
self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp32"], "float")
self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp8"], "fp8_t")
def test_pipeline_to_ck(self):
self.assertEqual(
CommonTypeMappings.PIPELINE_TO_CK["mem"], "GemmPipelineAgBgCrMem"
)
self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_CK)
self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_CK)
def test_pipeline_to_base(self):
self.assertIn("mem", CommonTypeMappings.PIPELINE_TO_BASE)
self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_BASE)
self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_BASE)
def test_scheduler_to_ck(self):
self.assertIn("intrawave", CommonTypeMappings.SCHEDULER_TO_CK)
self.assertIn("interwave", CommonTypeMappings.SCHEDULER_TO_CK)
def test_epilogue_to_dispatcher(self):
self.assertIn("cshuffle", CommonTypeMappings.EPILOGUE_TO_DISPATCHER)
self.assertIn("default", CommonTypeMappings.EPILOGUE_TO_DISPATCHER)
def test_layout_to_ck(self):
self.assertIn("r", CommonTypeMappings.LAYOUT_TO_CK)
self.assertIn("c", CommonTypeMappings.LAYOUT_TO_CK)
def test_get_output_dtype(self):
self.assertEqual(CommonTypeMappings.get_output_dtype("fp8"), "fp16")
self.assertEqual(CommonTypeMappings.get_output_dtype("bf8"), "fp16")
self.assertEqual(CommonTypeMappings.get_output_dtype("fp16"), "fp16")
self.assertEqual(CommonTypeMappings.get_output_dtype("fp32"), "fp32")
class TestGenerateCppCompilationUnit(unittest.TestCase):
"""Tests for generate_cpp_compilation_unit."""
def test_includes_kernel_header(self):
result = generate_cpp_compilation_unit("my_kernel")
self.assertIn('#include "my_kernel.hpp"', result)
def test_contains_pragma_once_or_guard(self):
result = generate_cpp_compilation_unit("test_kernel")
self.assertIn("test_kernel", result)
def test_different_names_different_output(self):
a = generate_cpp_compilation_unit("kernel_a")
b = generate_cpp_compilation_unit("kernel_b")
self.assertNotEqual(a, b)
class TestParallelGenerate(unittest.TestCase):
"""Tests for parallel_generate helper."""
def _dummy_generate(self, item):
return f"generated_{item}"
def test_parallel_returns_all(self):
items = ["a", "b", "c", "d"]
results = parallel_generate(self._dummy_generate, items, parallel=True)
self.assertEqual(len(results), 4)
for item in items:
self.assertIn(f"generated_{item}", results)
def test_sequential_returns_all(self):
items = ["x", "y", "z"]
results = parallel_generate(self._dummy_generate, items, parallel=False)
self.assertEqual(len(results), 3)
for item in items:
self.assertIn(f"generated_{item}", results)
def test_empty_items(self):
results = parallel_generate(self._dummy_generate, [], parallel=True)
self.assertEqual(len(results), 0)
def test_logs_per_kernel_progress(self):
items = ["k1", "k2"]
with self.assertLogs(level="INFO") as cm:
parallel_generate(self._dummy_generate, items, parallel=False)
log_output = "\n".join(cm.output)
self.assertIn("k1", log_output)
self.assertIn("k2", log_output)
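A plausible sketch of the `parallel_generate` helper these tests target, using a thread pool and per-item INFO logging. This is an assumption-laden illustration; the real helper may use processes, a different worker count, or richer log messages:

```python
import logging
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)


def parallel_generate(generate_fn, items, parallel=True, max_workers=8):
    """Apply generate_fn to each item, optionally in a thread pool.

    Returns results in input order and logs one INFO line per item.
    """
    if parallel and len(items) > 1:
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            results = list(pool.map(generate_fn, items))
    else:
        results = [generate_fn(item) for item in items]
    for item in items:
        logger.info("Generated kernel source for %s", item)
    return results
```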
class TestArchAwareExpansion(unittest.TestCase):
"""Tests for arch-aware expansion helpers (best-of-conv)."""
def test_valid_wave_configs_gfx942(self):
configs = valid_wave_configs("gfx942")
self.assertIsInstance(configs, list)
self.assertIn([2, 2, 1], configs)
self.assertIn([1, 4, 1], configs)
def test_valid_wave_configs_unknown_arch(self):
configs = valid_wave_configs("gfx_unknown")
self.assertIsInstance(configs, list)
self.assertGreater(len(configs), 0)
def test_valid_warp_configs_gfx942_fp16(self):
configs = valid_warp_configs("gfx942", "fp16")
self.assertIsInstance(configs, list)
self.assertIn([32, 32, 16], configs)
def test_valid_warp_configs_unknown_arch(self):
configs = valid_warp_configs("gfx_unknown", "fp16")
self.assertIsInstance(configs, list)
self.assertGreater(len(configs), 0)
def test_valid_trait_configs_excludes_interwave_compute(self):
configs = valid_trait_configs()
self.assertIsInstance(configs, list)
self.assertNotIn(("compv3", "cshuffle", "interwave"), configs)
self.assertNotIn(("compv4", "cshuffle", "interwave"), configs)
def test_valid_trait_configs_includes_mem_interwave(self):
configs = valid_trait_configs()
has_mem_interwave = any(p == "mem" and s == "interwave" for p, s in configs)
self.assertTrue(has_mem_interwave)
def test_needs_wave_expansion_wildcard(self):
self.assertTrue(needs_wave_expansion({"wave_m": -1, "wave_n": 2}))
self.assertTrue(needs_wave_expansion({"wave_m": 2, "wave_n": -1}))
def test_needs_wave_expansion_explicit(self):
self.assertFalse(needs_wave_expansion({"wave_m": 2, "wave_n": 2}))
def test_needs_warp_expansion_wildcard(self):
self.assertTrue(needs_warp_expansion({"warp_m": -1, "warp_n": 32}))
def test_needs_warp_expansion_explicit(self):
self.assertFalse(needs_warp_expansion({"warp_m": 32, "warp_n": 32}))
def test_needs_pipeline_expansion_wildcard(self):
self.assertTrue(needs_pipeline_expansion({"pipeline": "*"}))
def test_needs_pipeline_expansion_explicit(self):
self.assertFalse(needs_pipeline_expansion({"pipeline": "compv4"}))
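The wildcard convention the expansion tests above pin down (`-1` for numeric fields, `"*"` for the pipeline) could be implemented with helpers as simple as these (a hypothetical sketch; field names are taken from the tests):

```python
def needs_wave_expansion(config: dict) -> bool:
    # -1 is the wildcard: expand into every arch-valid wave shape
    return config.get("wave_m") == -1 or config.get("wave_n") == -1


def needs_warp_expansion(config: dict) -> bool:
    # -1 is the wildcard: expand into every arch/dtype-valid warp tile
    return config.get("warp_m") == -1 or config.get("warp_n") == -1


def needs_pipeline_expansion(config: dict) -> bool:
    # "*" expands into every valid (pipeline, epilogue, scheduler) triple
    return config.get("pipeline") == "*"
```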
if __name__ == "__main__":
unittest.main()
@@ -0,0 +1,243 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Tests for python/dispatcher_common.py -- shared Python dispatcher utilities.
Phase 1b TDD: tests written BEFORE implementation exists.
Run: python3 -m pytest tests/test_dispatcher_common.py -v
"""
import io
import sys
import unittest
from pathlib import Path
from unittest.mock import patch
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from dispatcher_common import ( # noqa: E402
get_dispatcher_root,
get_ck_root,
get_build_dir,
get_generated_kernels_dir,
get_arch_filter_data,
ValidationResultBase,
validate_wave_config,
validate_warp_tile_config,
validate_trait_combo,
auto_correct_wave,
auto_correct_trait,
Colors,
print_phase,
print_success,
print_error,
print_info,
cleanup_generated_kernels,
)
class TestPathHelpers(unittest.TestCase):
"""Tests for path helper functions."""
def test_dispatcher_root_contains_codegen(self):
root = get_dispatcher_root()
self.assertTrue((root / "codegen").exists())
def test_ck_root_contains_include_or_is_parent(self):
root = get_ck_root()
self.assertTrue(root.exists())
self.assertEqual(root, get_dispatcher_root().parent)
def test_build_dir_is_under_dispatcher(self):
build = get_build_dir()
self.assertEqual(build.parent, get_dispatcher_root())
def test_generated_kernels_dir_under_build(self):
gen_dir = get_generated_kernels_dir()
self.assertEqual(gen_dir.parent, get_build_dir())
class TestGetArchFilterData(unittest.TestCase):
"""Tests for get_arch_filter_data."""
def test_returns_dict(self):
data = get_arch_filter_data()
self.assertIsInstance(data, dict)
def test_has_warp_combos(self):
data = get_arch_filter_data()
self.assertIn("warp_combos", data)
def test_has_warp_tile_combos(self):
data = get_arch_filter_data()
self.assertIn("warp_tile_combos", data)
def test_has_trait_unsupported(self):
data = get_arch_filter_data()
self.assertIn("trait_unsupported", data)
def test_has_supported_archs(self):
data = get_arch_filter_data()
self.assertIn("supported_archs", data)
self.assertIn("gfx942", data["supported_archs"])
def test_gfx942_wave_configs(self):
data = get_arch_filter_data()
gfx942 = data["warp_combos"].get("gfx942", [])
self.assertIn([2, 2, 1], gfx942)
class TestValidationResultBase(unittest.TestCase):
"""Tests for ValidationResultBase dataclass."""
def test_valid_result(self):
vr = ValidationResultBase(is_valid=True)
self.assertTrue(vr.is_valid)
self.assertEqual(vr.errors, [])
self.assertEqual(vr.warnings, [])
self.assertEqual(vr.suggested_fixes, {})
def test_invalid_result(self):
vr = ValidationResultBase(
is_valid=False,
errors=["bad wave"],
suggested_fixes={"wave_m": 2},
)
self.assertFalse(vr.is_valid)
self.assertEqual(len(vr.errors), 1)
self.assertIn("wave_m", vr.suggested_fixes)
class TestValidateWaveConfig(unittest.TestCase):
"""Tests for validate_wave_config."""
def test_valid_wave(self):
is_valid, msg = validate_wave_config([2, 2, 1], "gfx942")
self.assertTrue(is_valid)
self.assertEqual(msg, "")
def test_invalid_wave(self):
is_valid, msg = validate_wave_config([3, 3, 1], "gfx942")
self.assertFalse(is_valid)
self.assertIn("wave", msg.lower())
class TestValidateWarpTileConfig(unittest.TestCase):
"""Tests for validate_warp_tile_config."""
def test_valid_warp_tile(self):
is_valid, msg = validate_warp_tile_config([32, 32, 16], "gfx942", "fp16")
self.assertTrue(is_valid)
def test_invalid_warp_tile(self):
is_valid, msg = validate_warp_tile_config([99, 99, 99], "gfx942", "fp16")
self.assertFalse(is_valid)
self.assertIn("warp", msg.lower())
class TestValidateTraitCombo(unittest.TestCase):
"""Tests for validate_trait_combo."""
def test_valid_trait(self):
is_valid, msg = validate_trait_combo("compv3", "cshuffle", "intrawave")
self.assertTrue(is_valid)
def test_invalid_trait_interwave_compute(self):
is_valid, msg = validate_trait_combo("compv4", "cshuffle", "interwave")
self.assertFalse(is_valid)
def test_valid_mem_interwave(self):
is_valid, msg = validate_trait_combo("mem", "cshuffle", "interwave")
self.assertTrue(is_valid)
class TestAutoCorrectWave(unittest.TestCase):
"""Tests for auto_correct_wave."""
def test_corrects_invalid_wave(self):
corrected = auto_correct_wave([1, 1, 1], "gfx942")
self.assertIsInstance(corrected, list)
self.assertEqual(len(corrected), 3)
data = get_arch_filter_data()
valid_waves = data["warp_combos"].get("gfx942", [[2, 2, 1]])
self.assertIn(corrected, valid_waves)
class TestAutoCorrectTrait(unittest.TestCase):
"""Tests for auto_correct_trait."""
def test_corrects_invalid_scheduler(self):
corrected_pipeline, corrected_scheduler = auto_correct_trait(
"compv4", "interwave"
)
self.assertEqual(corrected_scheduler, "intrawave")
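The auto-correction behavior tested here mirrors the trait-validity rule: compute pipelines must fall back to intrawave scheduling. A minimal sketch (hypothetical; the real function may also normalize other fields):

```python
def auto_correct_trait(pipeline: str, scheduler: str) -> tuple:
    """Snap an invalid (pipeline, scheduler) pair to the nearest valid one."""
    # compv3/compv4 do not support interwave scheduling, so fall back to
    # intrawave; the memory pipeline keeps whichever scheduler was given.
    if pipeline in ("compv3", "compv4") and scheduler == "interwave":
        return pipeline, "intrawave"
    return pipeline, scheduler
```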
class TestColors(unittest.TestCase):
"""Tests for Colors class (cross-platform ANSI support from conv)."""
def test_green_returns_string(self):
result = Colors.green("ok")
self.assertIn("ok", result)
def test_red_returns_string(self):
result = Colors.red("error")
self.assertIn("error", result)
def test_yellow_returns_string(self):
result = Colors.yellow("warn")
self.assertIn("warn", result)
def test_bold_returns_string(self):
result = Colors.bold("title")
self.assertIn("title", result)
def test_plain_mode_no_ansi(self):
with patch.object(Colors, "_use_color", return_value=False):
result = Colors.green("plain")
self.assertEqual(result, "plain")
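A `Colors` class passing these tests only needs a `_use_color()` gate (patched to `False` in the plain-mode test) around the ANSI escapes. A hypothetical cross-platform sketch:

```python
import os
import sys


class Colors:
    """ANSI color helper: emits escape codes only on real TTYs."""

    _CODES = {"green": "32", "red": "31", "yellow": "33", "bold": "1"}

    @classmethod
    def _use_color(cls) -> bool:
        # Respect NO_COLOR and skip escapes when output is piped to a file
        return sys.stdout.isatty() and "NO_COLOR" not in os.environ

    @classmethod
    def _wrap(cls, code: str, text: str) -> str:
        return f"\033[{code}m{text}\033[0m" if cls._use_color() else text

    @classmethod
    def green(cls, text: str) -> str:
        return cls._wrap(cls._CODES["green"], text)

    @classmethod
    def red(cls, text: str) -> str:
        return cls._wrap(cls._CODES["red"], text)

    @classmethod
    def yellow(cls, text: str) -> str:
        return cls._wrap(cls._CODES["yellow"], text)

    @classmethod
    def bold(cls, text: str) -> str:
        return cls._wrap(cls._CODES["bold"], text)
```

Because `_use_color` is a classmethod, `unittest.mock.patch.object(Colors, "_use_color", return_value=False)` cleanly disables the escapes for the plain-mode test.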
class TestPhasedOutput(unittest.TestCase):
"""Tests for phased output helpers."""
def test_print_phase(self):
buf = io.StringIO()
with patch("sys.stdout", buf):
print_phase(1, "Setup")
self.assertIn("Setup", buf.getvalue())
def test_print_success(self):
buf = io.StringIO()
with patch("sys.stdout", buf):
print_success("Done")
self.assertIn("Done", buf.getvalue())
def test_print_error(self):
buf = io.StringIO()
with patch("sys.stdout", buf):
print_error("Oops")
self.assertIn("Oops", buf.getvalue())
def test_print_info(self):
buf = io.StringIO()
with patch("sys.stdout", buf):
print_info("FYI")
self.assertIn("FYI", buf.getvalue())
class TestCleanup(unittest.TestCase):
"""Tests for cleanup_generated_kernels."""
def test_cleanup_nonexistent_dir_no_error(self):
cleanup_generated_kernels(Path("/tmp/nonexistent_ck_test_dir_12345"))
if __name__ == "__main__":
unittest.main()
@@ -28,14 +28,18 @@ sys.path.insert(0, str(PYTHON_DIR))
 def run_python_example(
-    example_path: Path, timeout: int = 120
+    example_path: Path, timeout: int = 120, extra_args: list = None
 ) -> subprocess.CompletedProcess:
     """Run a Python example and capture output."""
     env = os.environ.copy()
     env["PYTHONPATH"] = str(PYTHON_DIR)
+    cmd = [sys.executable, str(example_path)]
+    if extra_args:
+        cmd.extend(extra_args)
     return subprocess.run(
-        [sys.executable, str(example_path)],
+        cmd,
         capture_output=True,
         text=True,
         timeout=timeout,
@@ -111,61 +115,74 @@ class TestGemmPythonExamples(unittest.TestCase):
         result = run_python_example(example)
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
         # Should pass validation
         self.assertIn("PASS", result.stdout.upper(), "Validation should pass")
 class TestConvPythonExamples(unittest.TestCase):
-    """Test Conv Python examples."""
+    """Test grouped conv Python examples."""
     @classmethod
     def setUpClass(cls):
         """Check if examples directory exists."""
-        cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python"
+        cls.conv_examples_dir = EXAMPLES_DIR / "grouped_conv" / "python"
         if not cls.conv_examples_dir.exists():
-            raise unittest.SkipTest("Conv Python examples not found")
+            raise unittest.SkipTest("Grouped conv Python examples not found")
-    def test_01_basic_conv(self):
-        """Test basic conv example."""
-        example = self.conv_examples_dir / "01_basic_conv.py"
+    def test_01_basic_grouped_conv(self):
+        """Test basic grouped conv example."""
+        example = self.conv_examples_dir / "01_basic_grouped_conv.py"
         if not example.exists():
             self.skipTest(f"{example.name} not found")
         result = run_python_example(example)
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
+        self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS")
         self.assertIn("PASS", result.stdout.upper())
-    def test_02_conv2d_fwd(self):
-        """Test 2D forward conv example."""
-        example = self.conv_examples_dir / "02_conv2d_fwd.py"
+    def test_02_forward(self):
+        """Test forward conv example (2D + 3D)."""
+        example = self.conv_examples_dir / "02_forward.py"
         if not example.exists():
             self.skipTest(f"{example.name} not found")
         result = run_python_example(example)
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
         self.assertIn("PASS", result.stdout.upper())
-    def test_03_conv3d_fwd(self):
-        """Test 3D forward conv example."""
-        example = self.conv_examples_dir / "03_conv3d_fwd.py"
+    def test_03_bwd_data(self):
+        """Test backward data example."""
+        example = self.conv_examples_dir / "03_bwd_data.py"
         if not example.exists():
             self.skipTest(f"{example.name} not found")
         result = run_python_example(example)
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
         self.assertIn("PASS", result.stdout.upper())
-    def test_07_validation(self):
-        """Test validation example."""
-        example = self.conv_examples_dir / "07_validation.py"
+    def test_04_bwd_weight(self):
+        """Test backward weight example."""
+        example = self.conv_examples_dir / "04_bwd_weight.py"
         if not example.exists():
             self.skipTest(f"{example.name} not found")
         result = run_python_example(example)
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
-        self.assertIn("PASS", result.stdout.upper(), "Validation should pass")
+        self.assertIn("PASS", result.stdout.upper())
+    def test_05_benchmark(self):
+        """Test benchmark example."""
+        example = self.conv_examples_dir / "05_benchmark.py"
+        if not example.exists():
+            self.skipTest(f"{example.name} not found")
+        result = run_python_example(
+            example, extra_args=["--warmup", "1", "--repeat", "1"]
+        )
+        self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
+        self.assertIn("PASS", result.stdout.upper())
+    def test_06_registry_json(self):
+        """Test registry + heuristic + JSON example."""
+        example = self.conv_examples_dir / "06_registry_json.py"
+        if not example.exists():
+            self.skipTest(f"{example.name} not found")
+        result = run_python_example(example)
+        self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
+        self.assertIn("PASS", result.stdout.upper())
 class TestGemmCppExamples(unittest.TestCase):
@@ -195,18 +212,18 @@ class TestGemmCppExamples(unittest.TestCase):
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
-    def test_gemm_04_validation(self):
-        """Test validation GEMM C++ example."""
-        result = run_cpp_example("gemm_04_validation")
+    def test_gemm_03_benchmark_validation(self):
+        """Test benchmark+validation GEMM C++ example."""
+        result = run_cpp_example("gemm_03_benchmark_validation")
         if result is None:
-            self.skipTest("gemm_04_validation not built")
+            self.skipTest("gemm_03_benchmark_validation not built")
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
         self.assertIn("PASS", result.stdout.upper(), "Validation should pass")
 class TestConvCppExamples(unittest.TestCase):
-    """Test Conv C++ examples."""
+    """Test grouped conv C++ examples."""
     @classmethod
     def setUpClass(cls):
@@ -215,23 +232,29 @@ class TestConvCppExamples(unittest.TestCase):
         if not cls.examples_dir.exists():
             raise unittest.SkipTest("C++ examples not built")
-    def test_conv_01_forward(self):
-        """Test forward conv C++ example."""
-        result = run_cpp_example("conv_01_forward")
+    def test_grouped_conv_01_basic(self):
+        """Test basic grouped conv C++ example."""
+        result = run_cpp_example("grouped_conv_01_basic")
         if result is None:
-            self.skipTest("conv_01_forward not built")
+            self.skipTest("grouped_conv_01_basic not built")
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
+        self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS")
+        self.assertIn("PASS", result.stdout.upper())
-    def test_conv_02_validation(self):
-        """Test validation conv C++ example."""
-        result = run_cpp_example("conv_02_validation")
+    def test_grouped_conv_02_all_dirs(self):
+        """Test all directions grouped conv C++ example."""
+        result = run_cpp_example("grouped_conv_02_all_dirs")
         if result is None:
-            self.skipTest("conv_02_validation not built")
+            self.skipTest("grouped_conv_02_all_dirs not built")
         self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
-        self.assertIn("PASS", result.stdout.upper(), "Validation should pass")
+        self.assertIn("PASS", result.stdout.upper())
+    def test_grouped_conv_03_bench_val(self):
+        """Test benchmark+validation grouped conv C++ example."""
+        result = run_cpp_example("grouped_conv_03_bench_val")
+        if result is None:
+            self.skipTest("grouped_conv_03_bench_val not built")
+        self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}")
+        self.assertIn("PASS", result.stdout.upper())
 class TestUtilityImports(unittest.TestCase):
@@ -246,14 +269,18 @@ class TestUtilityImports(unittest.TestCase):
         except ImportError as e:
             self.fail(f"Failed to import ctypes_utils: {e}")
-    def test_import_conv_utils(self):
-        """Test importing conv_utils."""
+    def test_import_grouped_conv_utils(self):
+        """Test importing grouped_conv_utils."""
         try:
-            from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem  # noqa: F401
+            from grouped_conv_utils import (  # noqa: F401
+                GroupedConvValidationResult,
+                validate_grouped_conv_config,
+                GroupedConvDataType,
+            )
             self.assertTrue(True)
         except ImportError as e:
-            self.fail(f"Failed to import conv_utils: {e}")
+            self.fail(f"Failed to import grouped_conv_utils: {e}")
     def test_kernel_config_creation(self):
         """Test creating a KernelConfig."""
@@ -272,22 +299,19 @@ class TestUtilityImports(unittest.TestCase):
         self.assertEqual(config.dtype_a, "fp16")
         self.assertEqual(config.layout_a, "row")
-    def test_conv_signature_creation(self):
-        """Test creating a ConvSignature."""
-        from conv_utils import ConvSignature
+    def test_grouped_conv_default_config(self):
+        """Test creating a grouped conv default config."""
+        from grouped_conv_utils import get_grouped_conv_default_config
-        sig = ConvSignature(
-            dtype_in="fp16",
-            dtype_wei="fp16",
-            dtype_out="fp16",
-            dtype_acc="fp32",
-            layout="nhwgc",
-            direction="forward",
-            num_dims=2,
+        config = get_grouped_conv_default_config(
+            variant="forward",
+            ndim_spatial=2,
+            arch="gfx942",
         )
-        self.assertEqual(sig.dtype_in, "fp16")
-        self.assertEqual(sig.direction, "forward")
+        d = config.to_dict() if hasattr(config, "to_dict") else config
+        self.assertEqual(d["variant"], "forward")
+        self.assertEqual(d["arch"], "gfx942")
 class TestAutoCorrection(unittest.TestCase):
@@ -316,21 +340,22 @@ class TestAutoCorrection(unittest.TestCase):
         self.assertTrue(was_modified, "Config should be modified")
         self.assertGreater(len(corrections), 0, "Should have corrections")
-    def test_conv_auto_correct(self):
-        """Test Conv auto-correction."""
-        from conv_utils import auto_correct_conv_config
-        # Call with invalid wave config parameters
-        corrected, was_modified, corrections = auto_correct_conv_config(
-            wave_m=99,  # Invalid
-            wave_n=99,  # Invalid
-            wave_k=99,  # Invalid
-            dtype="fp16",
-            arch="gfx942",
+    def test_grouped_conv_auto_correct(self):
+        """Test Grouped Conv auto-correction."""
+        from grouped_conv_utils import (
+            auto_correct_grouped_conv_config,
+            get_grouped_conv_default_config,
         )
-        self.assertTrue(was_modified, "Config should be modified")
-        self.assertGreater(len(corrections), 0, "Should have corrections")
+        config = get_grouped_conv_default_config()
+        d = config.to_dict() if hasattr(config, "to_dict") else config
+        d["tile_config"]["warp_m"] = [99]
+        d["tile_config"]["warp_n"] = [99]
+        corrected, result = auto_correct_grouped_conv_config(d)
+        self.assertIsInstance(corrected, dict)
+        self.assertIn("tile_config", corrected)
 if __name__ == "__main__":
@@ -0,0 +1,589 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
TDD tests for codegen/unified_grouped_conv_codegen.py -- grouped convolution code generator.
These tests are written BEFORE the implementation exists.
Run: python3 -m pytest dispatcher/tests/test_grouped_conv_codegen.py -v
"""
import sys
import unittest
from pathlib import Path
from unittest.mock import patch
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
from codegen_common import TileConfig, TraitConfigBase # noqa: E402
from unified_grouped_conv_codegen import ( # noqa: E402
GroupedConvVariant,
GroupedConvLayout,
GroupedConvKernelConfig,
GroupedConvTypeMappings,
GroupedConvTraitConfig,
CKTileGroupedConvKernelGenerator,
GroupedConvDispatcherWrapperGenerator,
UnifiedGroupedConvCodegen,
)
# =============================================================================
# TestGroupedConvVariant
# =============================================================================
class TestGroupedConvVariant(unittest.TestCase):
"""Test GroupedConvVariant enum values."""
def test_forward_value(self):
self.assertEqual(GroupedConvVariant.FORWARD.value, "forward")
def test_backward_data_value(self):
self.assertEqual(GroupedConvVariant.BACKWARD_DATA.value, "bwd_data")
def test_backward_weight_value(self):
self.assertEqual(GroupedConvVariant.BACKWARD_WEIGHT.value, "bwd_weight")
def test_all_variants_exist(self):
self.assertIn(GroupedConvVariant.FORWARD, GroupedConvVariant)
self.assertIn(GroupedConvVariant.BACKWARD_DATA, GroupedConvVariant)
self.assertIn(GroupedConvVariant.BACKWARD_WEIGHT, GroupedConvVariant)
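The enum values asserted above fully determine a minimal `GroupedConvVariant`; a sketch consistent with these tests (the real implementation may carry extra helpers):

```python
from enum import Enum


class GroupedConvVariant(Enum):
    """The three grouped convolution directions the dispatcher generates."""
    FORWARD = "forward"            # out = conv(in, wei)
    BACKWARD_DATA = "bwd_data"     # gradient w.r.t. the input tensor
    BACKWARD_WEIGHT = "bwd_weight" # gradient w.r.t. the weight tensor
```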
# =============================================================================
# TestGroupedConvLayout
# =============================================================================
class TestGroupedConvLayout(unittest.TestCase):
"""Test GroupedConvLayout enum for 1D/2D/3D layouts."""
def test_nhwgc_value(self):
self.assertEqual(GroupedConvLayout.NHWGC.value, "NHWGC")
def test_gkyxc_value(self):
self.assertEqual(GroupedConvLayout.GKYXC.value, "GKYXC")
def test_nhwgk_value(self):
self.assertEqual(GroupedConvLayout.NHWGK.value, "NHWGK")
def test_1d_layouts_exist(self):
"""1D conv layouts (e.g., NWGC, GYXC, NWGK)."""
layouts_1d = [
lay
for lay in GroupedConvLayout
if "W" in lay.value and "H" not in lay.value
]
self.assertGreater(len(layouts_1d), 0)
def test_2d_layouts_exist(self):
"""2D conv layouts (e.g., NHWGC, GKYXC, NHWGK)."""
layouts_2d = [lay for lay in GroupedConvLayout if "HW" in lay.value]
self.assertGreater(len(layouts_2d), 0)
def test_3d_layouts_exist(self):
"""3D conv layouts (e.g., NDHWGC, GDKYXC)."""
layouts_3d = [
lay for lay in GroupedConvLayout if "D" in lay.value or "DHW" in lay.value
]
self.assertGreater(len(layouts_3d), 0)
# =============================================================================
# TestGroupedConvKernelConfig
# =============================================================================
class TestGroupedConvKernelConfig(unittest.TestCase):
"""Test GroupedConvKernelConfig dataclass."""
def _make_tile(self):
return TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
def _make_trait(self):
return GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=False,
num_groups_to_merge=1,
)
def test_name_contains_grouped_conv_fwd(self):
config = GroupedConvKernelConfig(
tile=self._make_tile(),
trait=self._make_trait(),
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
name = config.name("fp16")
self.assertIn("grouped_conv_fwd", name)
def test_name_backward_data_contains_bwd_data(self):
config = GroupedConvKernelConfig(
tile=self._make_tile(),
trait=self._make_trait(),
variant=GroupedConvVariant.BACKWARD_DATA,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
name = config.name("fp16")
self.assertIn("bwd_data", name)
def test_is_valid_for_arch_supported(self):
config = GroupedConvKernelConfig(
tile=self._make_tile(),
trait=self._make_trait(),
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
self.assertTrue(config.is_valid_for_arch("gfx942"))
def test_is_valid_for_arch_unsupported(self):
config = GroupedConvKernelConfig(
tile=self._make_tile(),
trait=self._make_trait(),
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
self.assertFalse(config.is_valid_for_arch("gfx600"))
# =============================================================================
# TestGroupedConvTypeMappings
# =============================================================================
class TestGroupedConvTypeMappings(unittest.TestCase):
"""Test GroupedConvTypeMappings class."""
def test_dtype_to_ck_fp16(self):
self.assertEqual(GroupedConvTypeMappings.DTYPE_TO_CK["fp16"], "half_t")
def test_dtype_to_ck_bf16(self):
self.assertIn("bf16", GroupedConvTypeMappings.DTYPE_TO_CK)
def test_dtype_to_ck_fp32(self):
self.assertIn("fp32", GroupedConvTypeMappings.DTYPE_TO_CK)
def test_get_layouts_2d_has_in_wei_out_keys(self):
layouts = GroupedConvTypeMappings.get_layouts(2)
self.assertIn("in", layouts)
self.assertIn("wei", layouts)
self.assertIn("out", layouts)
def test_get_layouts_2d_returns_dict(self):
layouts = GroupedConvTypeMappings.get_layouts(2)
self.assertIsInstance(layouts, dict)
def test_get_layouts_1d(self):
layouts = GroupedConvTypeMappings.get_layouts(1)
self.assertIn("in", layouts)
self.assertIn("wei", layouts)
self.assertIn("out", layouts)
def test_get_layouts_3d(self):
layouts = GroupedConvTypeMappings.get_layouts(3)
self.assertIn("in", layouts)
self.assertIn("wei", layouts)
self.assertIn("out", layouts)
# =============================================================================
# TestCKTileGroupedConvKernelGenerator
# =============================================================================
class TestCKTileGroupedConvKernelGenerator(unittest.TestCase):
"""Test CKTileGroupedConvKernelGenerator.generate()."""
def _make_config(self):
tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
trait = GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=False,
num_groups_to_merge=1,
)
return GroupedConvKernelConfig(
tile=tile,
trait=trait,
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
def test_generate_contains_pragma_once(self):
gen = CKTileGroupedConvKernelGenerator("fp16")
config = self._make_config()
result = gen.generate(config)
self.assertIn("#pragma once", result)
def test_generate_contains_forward_kernel_include(self):
gen = CKTileGroupedConvKernelGenerator("fp16")
config = self._make_config()
result = gen.generate(config)
self.assertIn("grouped_convolution_forward_kernel.hpp", result)
def test_generate_returns_non_empty_string(self):
gen = CKTileGroupedConvKernelGenerator("fp16")
config = self._make_config()
result = gen.generate(config)
self.assertIsInstance(result, str)
self.assertGreater(len(result), 100)
def test_generate_valid_cpp_structure(self):
gen = CKTileGroupedConvKernelGenerator("fp16")
config = self._make_config()
result = gen.generate(config)
self.assertIn("#include", result)
self.assertIn("ck_tile", result)
# =============================================================================
# TestGroupedConvDispatcherWrapperGenerator
# =============================================================================
class TestGroupedConvDispatcherWrapperGenerator(unittest.TestCase):
"""Test GroupedConvDispatcherWrapperGenerator.generate()."""
def _make_config(self):
tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
trait = GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=False,
num_groups_to_merge=1,
)
return GroupedConvKernelConfig(
tile=tile,
trait=trait,
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
def test_generate_contains_dispatcher_registration(self):
gen = GroupedConvDispatcherWrapperGenerator("fp16")
config = self._make_config()
kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp"
output_dir = DISPATCHER_DIR / "build" / "generated"
result = gen.generate(config, kernel_path, output_dir)
self.assertIn("dispatcher", result)
self.assertIn("KernelKey", result)
self.assertIn("KernelInstancePtr", result)
def test_generate_contains_pragma_once(self):
gen = GroupedConvDispatcherWrapperGenerator("fp16")
config = self._make_config()
kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp"
output_dir = DISPATCHER_DIR / "build" / "generated"
result = gen.generate(config, kernel_path, output_dir)
self.assertIn("#pragma once", result)
def test_generate_valid_cpp(self):
gen = GroupedConvDispatcherWrapperGenerator("fp16")
config = self._make_config()
kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp"
output_dir = DISPATCHER_DIR / "build" / "generated"
result = gen.generate(config, kernel_path, output_dir)
self.assertIn("#include", result)
self.assertIn("namespace", result)
# =============================================================================
# TestUnifiedGroupedConvCodegen
# =============================================================================
class TestUnifiedGroupedConvCodegen(unittest.TestCase):
"""Test UnifiedGroupedConvCodegen.generate_all()."""
def test_generate_all_returns_dict_with_expected_keys(self):
output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv"
output_dir.mkdir(parents=True, exist_ok=True)
codegen = UnifiedGroupedConvCodegen(
output_dir=output_dir,
datatype="fp16",
ndim_spatial=2,
gpu_target="gfx942",
)
with patch.object(
codegen,
"_get_configs",
return_value=[], # Mock empty config list for fast test
):
results = codegen.generate_all(parallel=False)
self.assertIn("kernels", results)
self.assertIn("wrappers", results)
self.assertIn("failed", results)
self.assertIsInstance(results["kernels"], list)
self.assertIsInstance(results["wrappers"], list)
self.assertIsInstance(results["failed"], list)
def test_generate_all_with_mock_config_produces_output(self):
output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv_test"
output_dir.mkdir(parents=True, exist_ok=True)
codegen = UnifiedGroupedConvCodegen(
output_dir=output_dir,
datatype="fp16",
ndim_spatial=2,
gpu_target="gfx942",
)
# Use a real config - patch the config source to return one config
tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
trait = GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=False,
num_groups_to_merge=1,
)
config = GroupedConvKernelConfig(
tile=tile,
trait=trait,
variant=GroupedConvVariant.FORWARD,
ndim_spatial=2,
arch="gfx942",
layout=GroupedConvLayout.NHWGC,
vector_sizes=(4, 4, 4),
)
with patch.object(codegen, "_get_configs", return_value=[config]):
results = codegen.generate_all(parallel=False)
self.assertIsInstance(results, dict)
self.assertIn("kernels", results)
# =============================================================================
# TestSharedImports
# =============================================================================
class TestSharedImports(unittest.TestCase):
"""Verify that TileConfig comes from codegen_common and that GroupedConvTraitConfig extends TraitConfigBase."""
def test_tile_config_has_expected_fields(self):
"""TileConfig from codegen_common has tile_m, tile_n, tile_k, etc."""
tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
self.assertEqual(tc.tile_m, 128)
self.assertEqual(tc.tile_n, 128)
self.assertEqual(tc.tile_k, 32)
self.assertEqual(tc.warp_m, 2)
self.assertEqual(tc.warp_n, 2)
self.assertEqual(tc.warp_k, 1)
self.assertEqual(tc.warp_tile_m, 32)
self.assertEqual(tc.warp_tile_n, 32)
self.assertEqual(tc.warp_tile_k, 16)
def test_tile_config_is_from_codegen_common(self):
"""TileConfig used by grouped conv is the same as codegen_common.TileConfig."""
tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
self.assertTrue(tc.is_valid())
def test_grouped_conv_trait_config_extends_trait_config_base(self):
"""GroupedConvTraitConfig extends TraitConfigBase."""
self.assertTrue(issubclass(GroupedConvTraitConfig, TraitConfigBase))
def test_grouped_conv_trait_config_has_double_smem_buffer(self):
"""GroupedConvTraitConfig has double_smem_buffer field."""
trait = GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=True,
num_groups_to_merge=2,
)
self.assertTrue(trait.double_smem_buffer)
self.assertEqual(trait.num_groups_to_merge, 2)
def test_grouped_conv_trait_config_has_num_groups_to_merge(self):
"""GroupedConvTraitConfig has num_groups_to_merge field."""
trait = GroupedConvTraitConfig(
"mem",
"cshuffle",
"intrawave",
False,
False,
False,
double_smem_buffer=False,
num_groups_to_merge=4,
)
self.assertEqual(trait.num_groups_to_merge, 4)
def test_grouped_conv_trait_config_inherits_base_fields(self):
"""GroupedConvTraitConfig inherits pipeline, epilogue, scheduler from base."""
trait = GroupedConvTraitConfig(
"compv4",
"cshuffle",
"intrawave",
True,
True,
True,
double_smem_buffer=False,
num_groups_to_merge=1,
)
self.assertEqual(trait.pipeline, "compv4")
self.assertEqual(trait.epilogue, "cshuffle")
self.assertEqual(trait.scheduler, "intrawave")
self.assertTrue(trait.pad_m)
self.assertTrue(trait.pad_n)
self.assertTrue(trait.pad_k)
# =============================================================================
# TestTwoStageBwdWeightCodegen
# =============================================================================
def _make_two_stage_config():
"""Helper: create a two-stage bwd_weight config."""
return GroupedConvKernelConfig(
tile=TileConfig(16, 64, 64, 1, 4, 1, 16, 16, 32),
trait=GroupedConvTraitConfig(
pipeline="compv3",
epilogue="cshuffle",
scheduler="intrawave",
pad_m=True,
pad_n=True,
pad_k=True,
two_stage=True,
),
variant=GroupedConvVariant.BACKWARD_WEIGHT,
ndim_spatial=2,
arch="gfx942",
)
class TestTwoStageBwdWeightCodegen(unittest.TestCase):
"""Tests for two-stage backward weight kernel generation."""
def test_kernel_name_contains_2stage(self):
config = _make_two_stage_config()
name = config.name("fp16")
self.assertIn("_2stage", name)
self.assertIn("bwd_weight", name)
def test_single_stage_name_has_no_2stage(self):
config = _make_two_stage_config()
config.trait.two_stage = False
name = config.name("fp16")
self.assertNotIn("_2stage", name)
def test_generate_contains_elementwise_include(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("elementwise.hpp", code)
def test_generate_contains_workspace_type(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("WorkspaceDataType", code)
def test_generate_contains_elementwise_kernel(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("ElementWiseKernel", code)
def test_generate_contains_launch_kernel_time_mask(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("launch_kernel_time_mask", code)
def test_generate_forces_vector_size_c_to_1(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("VectorSizeC_TwoStage = 1", code)
def test_generate_contains_workspace_memset(self):
config = _make_two_stage_config()
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertIn("hipMemsetAsync", code)
def test_single_stage_does_not_contain_workspace(self):
config = _make_two_stage_config()
config.trait.two_stage = False
gen = CKTileGroupedConvKernelGenerator(
"fp16", GroupedConvVariant.BACKWARD_WEIGHT
)
code = gen.generate(config)
self.assertNotIn("WorkspaceDataType", code)
self.assertNotIn("ElementWiseKernel", code)
self.assertNotIn("launch_kernel_time_mask", code)
def test_default_configs_include_two_stage(self):
from unified_grouped_conv_codegen import get_default_configs
configs = get_default_configs(
arch="gfx942",
variants=[GroupedConvVariant.BACKWARD_WEIGHT],
ndims=[2],
)
two_stage = [c for c in configs if c.trait.two_stage]
single_stage = [c for c in configs if not c.trait.two_stage]
self.assertGreater(len(two_stage), 0, "Should have two-stage configs")
self.assertGreater(
len(single_stage), 0, "Should still have single-stage configs"
)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,112 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Unit tests for GroupedConvConfig using assert() and std::cout
#include "ck_tile/dispatcher/grouped_conv_config.hpp"
#include <cassert>
#include <iostream>
using namespace ck_tile::dispatcher;
void test_grouped_conv_direction_enum()
{
std::cout << " test_grouped_conv_direction_enum... ";
assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::FORWARD) ==
std::string("fwd"));
assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_DATA) ==
std::string("bwd_data"));
assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_WEIGHT) ==
std::string("bwd_weight"));
std::cout << "PASSED\n";
}
void test_grouped_conv_signature_info()
{
std::cout << " test_grouped_conv_signature_info... ";
GroupedConvSignatureInfo sig;
assert(sig.spatial_dim == 2);
assert(sig.direction == GroupedConvDirection::FORWARD);
assert(sig.in_type == "fp16");
assert(sig.wei_type == "fp16");
assert(sig.out_type == "fp16");
assert(sig.acc_type == "fp32");
assert(sig.num_groups == 1);
sig.in_type = "bf16";
sig.num_groups = 4;
assert(sig.in_type == "bf16");
assert(sig.num_groups == 4);
std::cout << "PASSED\n";
}
void test_grouped_conv_algorithm_info()
{
std::cout << " test_grouped_conv_algorithm_info... ";
GroupedConvAlgorithmInfo algo;
assert(algo.tile.m == 128);
assert(algo.tile.n == 128);
assert(algo.tile.k == 64);
assert(algo.pipeline == PipelineVersion::V4);
assert(algo.scheduler == PipelineScheduler::INTRAWAVE);
assert(GroupedConvAlgorithmInfo::pipeline_str(PipelineVersion::V4) == std::string("compv4"));
assert(GroupedConvAlgorithmInfo::scheduler_str(PipelineScheduler::INTRAWAVE) ==
std::string("intrawave"));
std::cout << "PASSED\n";
}
void test_grouped_conv_config()
{
std::cout << " test_grouped_conv_config... ";
GroupedConvConfig cfg;
std::string name = cfg.name();
assert(!name.empty());
assert(name.find("grouped_conv_") != std::string::npos);
assert(name.find("fwd") != std::string::npos);
assert(name.find("fp16") != std::string::npos);
assert(name.find("2d") != std::string::npos);
std::string brief = cfg.brief();
assert(!brief.empty());
assert(brief.find("2D") != std::string::npos || brief.find("Grouped") != std::string::npos);
std::string detailed = cfg.detailed();
assert(!detailed.empty());
assert(detailed.find("Signature:") != std::string::npos);
assert(detailed.find("Algorithm:") != std::string::npos);
assert(detailed.find("Arch:") != std::string::npos);
std::cout << "PASSED\n";
}
void test_predefined_grouped_conv_configs()
{
std::cout << " test_predefined_grouped_conv_configs... ";
configs::Memory<float> mem_cfg;
assert(mem_cfg.algorithm.pipeline == PipelineVersion::MEMORY);
assert(mem_cfg.algorithm.tile.m == 128);
assert(mem_cfg.algorithm.tile.n == 32);
configs::CompV3_Small<float> compv3_small;
assert(compv3_small.algorithm.pipeline == PipelineVersion::V3);
assert(compv3_small.algorithm.tile.m == 16);
assert(compv3_small.algorithm.tile.n == 64);
configs::CompV4<float> compv4;
assert(compv4.algorithm.pipeline == PipelineVersion::V4);
assert(compv4.algorithm.double_smem_buffer == true);
configs::WMMA<float> wmma_cfg;
assert(wmma_cfg.arch.name == "gfx1100");
std::cout << "PASSED\n";
}
int main()
{
std::cout << "\n=== Test Grouped Conv Config ===\n\n";
test_grouped_conv_direction_enum();
test_grouped_conv_signature_info();
test_grouped_conv_algorithm_info();
test_grouped_conv_config();
test_predefined_grouped_conv_configs();
std::cout << "\n=== All Tests Passed! ===\n\n";
return 0;
}


@@ -0,0 +1,141 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Unit tests for GroupedConvKernelDecl using assert() and std::cout
#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp"
#include <cassert>
#include <iostream>
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_decl;
void test_grouped_conv_signature_builder()
{
std::cout << " test_grouped_conv_signature_builder... ";
GroupedConvSignature sig;
sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2).groups(4);
assert(sig.dtype_in_ == "fp16");
assert(sig.dtype_wei_ == "fp16");
assert(sig.dtype_out_ == "fp16");
assert(sig.layout_ == "nhwc");
assert(sig.conv_op_ == "forward");
assert(sig.num_dims_ == 2);
assert(sig.groups_ == 4);
assert(sig.op_str() == "fwd");
sig.conv_type("bwd_data");
assert(sig.op_str() == "bwd_data");
sig.conv_type("bwd_weight");
assert(sig.op_str() == "bwd_weight");
std::cout << "PASSED\n";
}
void test_grouped_conv_algorithm_builder()
{
std::cout << " test_grouped_conv_algorithm_builder... ";
GroupedConvAlgorithm algo;
algo.tile(128, 128, 64)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave");
assert(algo.tile_m_ == 128);
assert(algo.tile_n_ == 128);
assert(algo.tile_k_ == 64);
assert(algo.wave_m_ == 2);
assert(algo.wave_n_ == 2);
assert(algo.warp_m_ == 32);
assert(algo.warp_n_ == 32);
assert(algo.pipeline_ == "compv4");
assert(algo.scheduler_ == "intrawave");
assert(!algo.needs_expansion());
algo.wave_m_ = ANY_INT;
assert(algo.needs_wave_expansion());
std::cout << "PASSED\n";
}
void test_grouped_conv_kernel_decl()
{
std::cout << " test_grouped_conv_kernel_decl... ";
GroupedConvSignature sig;
sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2);
GroupedConvAlgorithm algo;
algo.tile(128, 128, 64).wave(2, 2, 1).warp(32, 32, 16);
GroupedConvKernelDecl decl(sig, algo, "gfx942");
std::string name = decl.name();
assert(!name.empty());
assert(name.find("grouped_conv_") != std::string::npos);
assert(name.find("fwd") != std::string::npos);
assert(name.find("fp16") != std::string::npos);
assert(name.find("128x128x64") != std::string::npos);
assert(!decl.has_wildcards());
std::cout << "PASSED\n";
}
void test_grouped_conv_kernel_set()
{
std::cout << " test_grouped_conv_kernel_set... ";
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
assert(set.size() == 1);
set.add("fp16", "nhwc", "forward", 256, 256);
assert(set.size() == 2);
const auto& decls = set.declarations();
assert(decls[0].algorithm.tile_n_ == 128);
assert(decls[0].algorithm.tile_k_ == 128);
assert(decls[1].algorithm.tile_n_ == 256);
assert(decls[1].algorithm.tile_k_ == 256);
set.tag("test_set");
assert(set.tag() == "test_set");
std::cout << "PASSED\n";
}
void test_grouped_conv_kernel_set_merge()
{
std::cout << " test_grouped_conv_kernel_set_merge... ";
GroupedConvKernelSet set1;
set1.add("fp16", "nhwc", "forward", 128, 128);
GroupedConvKernelSet set2;
set2.add("fp16", "nhwc", "forward", 256, 256);
set1.merge(set2);
assert(set1.size() == 2);
assert(set1.declarations()[0].algorithm.tile_n_ == 128);
assert(set1.declarations()[1].algorithm.tile_n_ == 256);
std::cout << "PASSED\n";
}
void test_grouped_conv_kernel_set_registry()
{
std::cout << " test_grouped_conv_kernel_set_registry... ";
auto& reg = GroupedConvKernelSetRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
reg.register_set("gconv_test", set);
assert(reg.has("gconv_test"));
assert(reg.size() >= 1);
const auto& retrieved = reg.get("gconv_test");
assert(retrieved.size() == 1);
const auto& empty = reg.get("nonexistent");
assert(empty.size() == 0);
reg.clear();
assert(!reg.has("gconv_test"));
std::cout << "PASSED\n";
}
int main()
{
std::cout << "\n=== Test Grouped Conv Kernel Decl ===\n\n";
test_grouped_conv_signature_builder();
test_grouped_conv_algorithm_builder();
test_grouped_conv_kernel_decl();
test_grouped_conv_kernel_set();
test_grouped_conv_kernel_set_merge();
test_grouped_conv_kernel_set_registry();
std::cout << "\n=== All Tests Passed! ===\n\n";
return 0;
}


@@ -0,0 +1,245 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Unit tests for GroupedConvProblem using assert() and std::cout
#include "ck_tile/dispatcher/grouped_conv_problem.hpp"
#include <cassert>
#include <iostream>
#include <stdexcept>
using namespace ck_tile::dispatcher;
void test_grouped_conv_problem_defaults()
{
std::cout << " test_grouped_conv_problem_defaults... ";
GroupedConvProblem p;
assert(p.N == 1);
assert(p.C == 64);
assert(p.K == 64);
assert(p.G == 1);
assert(p.Hi() == 28);
assert(p.Wi() == 28);
assert(p.Y() == 3);
assert(p.X() == 3);
assert(p.op == GroupedConvOp::Forward);
assert(p.stride[0] == 1 && p.stride[1] == 1 && p.stride[2] == 1);
assert(p.padding[0] == 0 && p.padding[1] == 0 && p.padding[2] == 0);
assert(p.dilation[0] == 1 && p.dilation[1] == 1 && p.dilation[2] == 1);
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_2d()
{
std::cout << " test_grouped_conv_problem_2d... ";
GroupedConvProblem p(4, 64, 128, 28, 28, 3, 3);
p.compute_output_size();
assert(p.N == 4);
assert(p.C == 64);
assert(p.K == 128);
assert(p.Hi() == 28);
assert(p.Wi() == 28);
assert(p.Y() == 3);
assert(p.X() == 3);
assert(p.Ho() == 26);
assert(p.Wo() == 26);
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_strided()
{
std::cout << " test_grouped_conv_problem_strided... ";
GroupedConvProblem p;
p.N = 1;
p.C = 64;
p.K = 64;
p.G = 1;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 3, 3};
p.stride = {1, 2, 2};
p.padding = {0, 1, 1};
p.dilation = {1, 1, 1};
p.compute_output_size();
assert(p.Ho() == 7);
assert(p.Wo() == 7);
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_grouped()
{
std::cout << " test_grouped_conv_problem_grouped... ";
GroupedConvProblem p;
p.N = 2;
p.C = 64;
p.K = 64;
p.G = 4;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 3, 3};
p.stride = {1, 1, 1};
p.padding = {0, 0, 0};
p.dilation = {1, 1, 1};
p.compute_output_size();
assert(p.G == 4);
assert(p.C % p.G == 0);
assert(p.K % p.G == 0);
assert(p.is_valid());
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_depthwise()
{
std::cout << " test_grouped_conv_problem_depthwise... ";
GroupedConvProblem p;
p.N = 2;
p.C = 64;
p.K = 64;
p.G = 64;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 3, 3};
p.stride = {1, 1, 1};
p.padding = {0, 0, 0};
p.dilation = {1, 1, 1};
p.compute_output_size();
assert(p.is_depthwise());
assert(p.G == p.C && p.G == p.K);
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_pointwise()
{
std::cout << " test_grouped_conv_problem_pointwise... ";
GroupedConvProblem p;
p.N = 2;
p.C = 64;
p.K = 128;
p.G = 1;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 1, 1};
p.stride = {1, 1, 1};
p.padding = {0, 0, 0};
p.dilation = {1, 1, 1};
p.compute_output_size();
assert(p.is_pointwise());
assert(p.Y() == 1 && p.X() == 1);
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_flops()
{
std::cout << " test_grouped_conv_problem_flops... ";
GroupedConvProblem p;
p.N = 2;
p.C = 64;
p.K = 64;
p.G = 1;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 3, 3};
p.stride = {1, 1, 1};
p.padding = {0, 0, 0};
p.dilation = {1, 1, 1};
p.compute_output_size();
double flops = p.get_flops();
assert(flops > 0);
assert(flops == 2.0 * p.N * p.K * p.Ho() * p.Wo() * (p.C / p.G) * p.Y() * p.X());
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_is_valid()
{
std::cout << " test_grouped_conv_problem_is_valid... ";
GroupedConvProblem p;
p.N = 1;
p.C = 64;
p.K = 64;
p.G = 1;
p.input_spatial = {1, 14, 14};
p.filter_spatial = {1, 3, 3};
p.compute_output_size();
assert(p.is_valid());
p.N = 0;
assert(!p.is_valid());
p.N = 1;
p.C = 0;
assert(!p.is_valid());
p.C = 64;
p.K = 0;
assert(!p.is_valid());
p.K = 64;
p.G = 0;
assert(!p.is_valid());
p.G = 1;
p.C = 64;
p.K = 64;
p.G = 3;
assert(!p.is_valid());
p.G = 4;
assert(p.is_valid());
std::cout << "PASSED\n";
}
void test_grouped_conv_problem_builder()
{
std::cout << " test_grouped_conv_problem_builder... ";
auto p = GroupedConvProblemBuilder()
.batch(8)
.channels(128, 256)
.groups(4)
.input_size(32, 32)
.filter_size(3, 3)
.stride(2, 2)
.padding(1, 1)
.dilation(1, 1)
.operation(GroupedConvOp::Forward)
.build();
assert(p.N == 8);
assert(p.C == 128);
assert(p.K == 256);
assert(p.G == 4);
assert(p.Hi() == 32);
assert(p.Wi() == 32);
assert(p.Y() == 3);
assert(p.X() == 3);
assert(p.stride[1] == 2 && p.stride[2] == 2);
assert(p.padding[1] == 1 && p.padding[2] == 1);
assert(p.op == GroupedConvOp::Forward);
assert(p.is_valid());
bool threw = false;
try
{
(void)GroupedConvProblemBuilder()
.batch(0)
.channels(64, 64)
.groups(1)
.input_size(14, 14)
.filter_size(3, 3)
.build();
}
catch(const std::invalid_argument&)
{
threw = true;
}
assert(threw);
std::cout << "PASSED\n";
}
int main()
{
std::cout << "\n=== Test Grouped Conv Problem ===\n\n";
test_grouped_conv_problem_defaults();
test_grouped_conv_problem_2d();
test_grouped_conv_problem_strided();
test_grouped_conv_problem_grouped();
test_grouped_conv_problem_depthwise();
test_grouped_conv_problem_pointwise();
test_grouped_conv_problem_flops();
test_grouped_conv_problem_is_valid();
test_grouped_conv_problem_builder();
std::cout << "\n=== All Tests Passed! ===\n\n";
return 0;
}
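The problem tests above pin down two pieces of arithmetic: the standard convolution output-size formula behind `compute_output_size()`, and the grouped-conv FLOP count behind `get_flops()`. A minimal Python sketch of both, matching the values the tests assert (the function names here are illustrative, not the dispatcher's actual API):

```python
def conv_out_size(in_size, filt, stride, pad, dilation=1):
    # Standard convolution output-size formula.
    return (in_size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1

def grouped_conv_flops(n, c, k, ho, wo, y, x, g):
    # 2 FLOPs (multiply + add) per MAC; each of the K output channels
    # reduces over only the C/G input channels within its group.
    return 2.0 * n * k * ho * wo * (c // g) * y * x

# Matches test_grouped_conv_problem_2d: 28x28 input, 3x3 filter, stride 1, pad 0
assert conv_out_size(28, 3, 1, 0) == 26
# Matches test_grouped_conv_problem_strided: 14x14 input, 3x3, stride 2, pad 1
assert conv_out_size(14, 3, 2, 1) == 7
# Matches test_grouped_conv_problem_flops: N=2, C=K=64, G=1, 14x14 input, 3x3
ho = wo = conv_out_size(14, 3, 1, 0)
assert grouped_conv_flops(2, 64, 64, ho, wo, 3, 3, 1) == 2.0 * 2 * 64 * 12 * 12 * 64 * 9
```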


@@ -0,0 +1,230 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Unit tests for GroupedConvRegistry and GroupedConvDispatcher using assert() and std::cout
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
#include <cassert>
#include <iostream>
#include <thread>
#include <atomic>
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::grouped_conv_decl;
void test_grouped_conv_registry_basic()
{
std::cout << " test_grouped_conv_registry_basic... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
reg.set_name("test_registry");
assert(reg.name() == "test_registry");
assert(reg.size() == 0);
assert(reg.empty());
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_register_set()
{
std::cout << " test_grouped_conv_registry_register_set... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
set.add("fp16", "nhwc", "forward", 256, 256);
bool ok = reg.register_set(set);
assert(ok);
assert(reg.size() == 2);
assert(!reg.empty());
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_all_kernels()
{
std::cout << " test_grouped_conv_registry_all_kernels... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
reg.register_set(set);
auto all = reg.all_kernels();
assert(all.size() == 1);
assert(all[0]->name().find("grouped_conv_") != std::string::npos);
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_clear()
{
std::cout << " test_grouped_conv_registry_clear... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
reg.register_set(set);
assert(reg.size() == 1);
reg.clear();
assert(reg.size() == 0);
assert(reg.empty());
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_thread_safe()
{
std::cout << " test_grouped_conv_registry_thread_safe... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
const int num_threads = 4;
const int sets_per_thread = 10;
std::vector<std::thread> threads;
std::atomic<int> success_count{0};
for(int t = 0; t < num_threads; t++)
{
threads.emplace_back([t, &reg, &success_count]() {
for(int k = 0; k < sets_per_thread; k++)
{
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128 + t * 32 + k, 128);
if(reg.register_set(set))
{
success_count++;
}
}
});
}
for(auto& th : threads)
th.join();
assert(reg.size() == num_threads * sets_per_thread);
assert(success_count.load() == num_threads * sets_per_thread);
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_export_json()
{
std::cout << " test_grouped_conv_registry_export_json... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
reg.register_set(set);
std::string json = reg.export_json(false);
assert(!json.empty());
assert(json.find("\"kernels\"") != std::string::npos);
assert(json.find("\"metadata\"") != std::string::npos);
assert(json.find("grouped_conv_") != std::string::npos);
std::string json_stats = reg.export_json(true);
assert(json_stats.find("\"statistics\"") != std::string::npos);
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_registry_filter()
{
std::cout << " test_grouped_conv_registry_filter... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
set.add("fp16", "nhwc", "forward", 256, 256);
set.add("bf16", "nhwc", "forward", 128, 128);
reg.register_set(set);
auto fp16_only =
reg.filter([](const GroupedConvKernelInstance& k) { return k.key().dtype_in == "fp16"; });
assert(fp16_only.size() == 2);
auto large_tile = reg.filter([](const GroupedConvKernelInstance& k) {
return k.key().tile_m >= 256 || k.key().tile_n >= 256;
});
assert(large_tile.size() >= 1);
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_dispatcher_basic()
{
std::cout << " test_grouped_conv_dispatcher_basic... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
reg.register_set(set);
GroupedConvDispatcher dispatcher(&reg);
GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem(
4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward);
float time = dispatcher.run(problem, nullptr);
assert(time >= 0.0f);
reg.clear();
std::cout << "PASSED\n";
}
void test_grouped_conv_dispatcher_select()
{
std::cout << " test_grouped_conv_dispatcher_select... ";
GroupedConvRegistry& reg = GroupedConvRegistry::instance();
reg.clear();
GroupedConvKernelSet set;
set.add("fp16", "nhwc", "forward", 128, 128);
set.add("fp16", "nhwc", "forward", 256, 256);
reg.register_set(set);
GroupedConvDispatcher dispatcher(&reg);
GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem(
4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward);
const auto* selected = dispatcher.select(problem);
assert(selected != nullptr);
assert(selected->name().find("grouped_conv_") != std::string::npos);
assert(selected->matches(problem));
reg.clear();
std::cout << "PASSED\n";
}
int main()
{
std::cout << "\n=== Test Grouped Conv Registry ===\n\n";
test_grouped_conv_registry_basic();
test_grouped_conv_registry_register_set();
test_grouped_conv_registry_all_kernels();
test_grouped_conv_registry_clear();
test_grouped_conv_registry_thread_safe();
test_grouped_conv_registry_export_json();
test_grouped_conv_registry_filter();
test_grouped_conv_dispatcher_basic();
test_grouped_conv_dispatcher_select();
std::cout << "\n=== All Tests Passed! ===\n\n";
return 0;
}
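`test_grouped_conv_registry_thread_safe` above depends on `register_set` being safe to call from several threads at once, with no registrations lost. The pattern it exercises can be sketched with a lock-guarded toy registry in Python (an illustrative stand-in, not the dispatcher's C++ implementation):

```python
import threading

class TinyRegistry:
    """Toy registry: all mutation and reads go through one lock."""
    def __init__(self):
        self._lock = threading.Lock()
        self._kernels = []

    def register_set(self, kernels):
        with self._lock:  # serialize concurrent registration
            self._kernels.extend(kernels)
            return True

    def size(self):
        with self._lock:
            return len(self._kernels)

# Mirror the test: 4 threads each register 10 one-kernel sets.
reg = TinyRegistry()
threads = [
    threading.Thread(
        target=lambda t=t: [reg.register_set([f"kernel_{t}_{k}"]) for k in range(10)]
    )
    for t in range(4)
]
for th in threads:
    th.start()
for th in threads:
    th.join()
assert reg.size() == 40  # 4 threads x 10 sets, none lost
```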


@@ -0,0 +1,349 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
TDD tests for python/grouped_conv_utils.py -- grouped convolution Python utilities.
Phase 1 TDD: tests written BEFORE implementation exists.
Run: python3 -m pytest tests/test_grouped_conv_utils.py -v
"""
import sys
import unittest
from pathlib import Path
SCRIPT_DIR = Path(__file__).parent.resolve()
DISPATCHER_DIR = SCRIPT_DIR.parent
sys.path.insert(0, str(DISPATCHER_DIR / "python"))
sys.path.insert(0, str(DISPATCHER_DIR / "codegen"))
from dispatcher_common import ValidationResultBase # noqa: E402
from grouped_conv_utils import ( # noqa: E402
GroupedConvValidationResult,
validate_grouped_conv_config,
auto_correct_grouped_conv_config,
get_grouped_conv_default_config,
GroupedConvDataType,
format_grouped_conv_summary,
)
# =============================================================================
# VALID CONFIG FIXTURES
# =============================================================================
def make_valid_grouped_conv_config():
"""Return a valid grouped conv config dict for gfx942."""
return {
"tile_config": {
"tile_k": 128,
"tile_c": 128,
"wave_m": 2,
"wave_n": 2,
"wave_k": 1,
"warp_m": 32,
"warp_n": 32,
"warp_k": 16,
},
"trait_config": {
"pipeline": "compv4",
"epilogue": "cshuffle",
"scheduler": "intrawave",
},
"variant": "2d_fwd",
"ndim_spatial": 2,
"arch": "gfx942",
"layout": "nhwgc",
"dtype": "fp16",
}
# =============================================================================
# TestGroupedConvValidationResult
# =============================================================================
class TestGroupedConvValidationResult(unittest.TestCase):
"""Tests for GroupedConvValidationResult dataclass."""
def test_inherits_from_validation_result_base(self):
"""GroupedConvValidationResult should inherit from ValidationResultBase."""
self.assertTrue(
issubclass(GroupedConvValidationResult, ValidationResultBase),
"GroupedConvValidationResult must inherit from ValidationResultBase",
)
def test_valid_result_has_is_valid(self):
"""Valid result has is_valid=True."""
vr = GroupedConvValidationResult(is_valid=True)
self.assertTrue(vr.is_valid)
def test_invalid_result_has_is_valid_false(self):
"""Invalid result has is_valid=False."""
vr = GroupedConvValidationResult(is_valid=False, errors=["bad config"])
self.assertFalse(vr.is_valid)
def test_has_errors_list(self):
"""Result has errors list."""
vr = GroupedConvValidationResult(
is_valid=False,
errors=["invalid wave", "invalid trait"],
)
self.assertEqual(len(vr.errors), 2)
self.assertIn("invalid wave", vr.errors)
self.assertIn("invalid trait", vr.errors)
def test_has_warnings_list(self):
"""Result has warnings list."""
vr = GroupedConvValidationResult(
is_valid=True,
warnings=["deprecated option"],
)
self.assertEqual(len(vr.warnings), 1)
self.assertIn("deprecated option", vr.warnings)
def test_has_suggested_fixes_dict(self):
"""Result has suggested_fixes dict."""
vr = GroupedConvValidationResult(
is_valid=False,
suggested_fixes={"wave_m": 2, "wave_n": 2},
)
self.assertIn("wave_m", vr.suggested_fixes)
self.assertEqual(vr.suggested_fixes["wave_m"], 2)
self.assertIn("wave_n", vr.suggested_fixes)
self.assertEqual(vr.suggested_fixes["wave_n"], 2)
def test_default_empty_errors_warnings_fixes(self):
"""Default result has empty errors, warnings, suggested_fixes."""
vr = GroupedConvValidationResult(is_valid=True)
self.assertEqual(vr.errors, [])
self.assertEqual(vr.warnings, [])
self.assertEqual(vr.suggested_fixes, {})
# =============================================================================
# TestValidateGroupedConvConfig
# =============================================================================
class TestValidateGroupedConvConfig(unittest.TestCase):
"""Tests for validate_grouped_conv_config."""
def test_valid_config_passes(self):
"""Valid config should pass validation."""
config = make_valid_grouped_conv_config()
result = validate_grouped_conv_config(config)
self.assertTrue(result.is_valid, f"Expected valid, got errors: {result.errors}")
self.assertEqual(result.errors, [])
def test_invalid_wave_config_fails(self):
"""Invalid wave config should fail validation."""
config = make_valid_grouped_conv_config()
config["tile_config"]["wave_m"] = 3
config["tile_config"]["wave_n"] = 3
result = validate_grouped_conv_config(config)
self.assertFalse(result.is_valid)
self.assertGreater(len(result.errors), 0)
error_str = " ".join(result.errors).lower()
self.assertIn("wave", error_str)
def test_invalid_trait_fails(self):
"""Invalid trait combination should fail validation."""
config = make_valid_grouped_conv_config()
config["trait_config"]["pipeline"] = "compv4"
config["trait_config"]["epilogue"] = "cshuffle"
config["trait_config"]["scheduler"] = "interwave" # Invalid combo
result = validate_grouped_conv_config(config)
self.assertFalse(result.is_valid)
self.assertGreater(len(result.errors), 0)
error_str = " ".join(result.errors).lower()
self.assertIn("trait", error_str)
def test_missing_fields_fails(self):
"""Config with missing required fields should fail validation."""
config = {"arch": "gfx942"} # Missing tile_config, trait_config, etc.
result = validate_grouped_conv_config(config)
self.assertFalse(result.is_valid)
self.assertGreater(len(result.errors), 0)
# =============================================================================
# TestAutoCorrectGroupedConvConfig
# =============================================================================
class TestAutoCorrectGroupedConvConfig(unittest.TestCase):
"""Tests for auto_correct_grouped_conv_config."""
def test_invalid_wave_gets_corrected(self):
"""Invalid wave config should be auto-corrected."""
config = make_valid_grouped_conv_config()
config["tile_config"]["wave_m"] = 3
config["tile_config"]["wave_n"] = 3
corrected, result = auto_correct_grouped_conv_config(config)
self.assertIsInstance(corrected, dict)
self.assertIsInstance(result, GroupedConvValidationResult)
# Corrected wave should be valid for arch
wave_m = corrected.get("tile_config", {}).get("wave_m")
wave_n = corrected.get("tile_config", {}).get("wave_n")
self.assertIn(wave_m, [1, 2, 4])
self.assertIn(wave_n, [1, 2, 4])
def test_invalid_trait_gets_corrected(self):
"""Invalid trait combination should be auto-corrected."""
config = make_valid_grouped_conv_config()
config["trait_config"]["scheduler"] = "interwave"
config["trait_config"]["pipeline"] = "compv4"
config["trait_config"]["epilogue"] = "cshuffle"
corrected, result = auto_correct_grouped_conv_config(config)
self.assertIsInstance(corrected, dict)
self.assertIsInstance(result, GroupedConvValidationResult)
# Scheduler should be corrected to intrawave for compv4+cshuffle
scheduler = corrected.get("trait_config", {}).get("scheduler")
self.assertEqual(scheduler, "intrawave")
# =============================================================================
# TestGetGroupedConvDefaultConfig
# =============================================================================
class TestGetGroupedConvDefaultConfig(unittest.TestCase):
"""Tests for get_grouped_conv_default_config."""
def test_returns_config(self):
"""Should return a GroupedConvKernelConfig (or dict via to_dict)."""
config = get_grouped_conv_default_config("2d_fwd")
# Accepts both dataclass and dict
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIsInstance(d, dict)
def test_has_tile_config(self):
"""Returned config has tile_config key."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("tile_config", d)
self.assertIsInstance(d["tile_config"], dict)
def test_has_trait_config(self):
"""Returned config has trait_config key."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("trait_config", d)
self.assertIsInstance(d["trait_config"], dict)
def test_has_variant(self):
"""Returned config has variant."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("variant", d)
def test_has_ndim_spatial(self):
"""Returned config has ndim_spatial."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("ndim_spatial", d)
def test_has_arch(self):
"""Returned config has arch."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("arch", d)
def test_has_layout(self):
"""Returned config has layout."""
config = get_grouped_conv_default_config("2d_fwd")
d = config.to_dict() if hasattr(config, "to_dict") else config
self.assertIn("layout", d)
# =============================================================================
# TestGroupedConvDataType
# =============================================================================
class TestGroupedConvDataType(unittest.TestCase):
"""Tests for GroupedConvDataType enum."""
def test_fp16_exists(self):
"""GroupedConvDataType has FP16."""
self.assertIsNotNone(GroupedConvDataType.FP16)
def test_bf16_exists(self):
"""GroupedConvDataType has BF16."""
self.assertIsNotNone(GroupedConvDataType.BF16)
def test_fp32_exists(self):
"""GroupedConvDataType has FP32."""
self.assertIsNotNone(GroupedConvDataType.FP32)
def test_fp8_exists(self):
"""GroupedConvDataType has FP8."""
self.assertIsNotNone(GroupedConvDataType.FP8)
def test_bf8_exists(self):
"""GroupedConvDataType has BF8."""
self.assertIsNotNone(GroupedConvDataType.BF8)
def test_int8_exists(self):
"""GroupedConvDataType has INT8."""
self.assertIsNotNone(GroupedConvDataType.INT8)
def test_enum_values_unique(self):
"""All enum values should be unique."""
values = [
GroupedConvDataType.FP16,
GroupedConvDataType.BF16,
GroupedConvDataType.FP32,
GroupedConvDataType.FP8,
GroupedConvDataType.BF8,
GroupedConvDataType.INT8,
]
self.assertEqual(len(values), len(set(values)))
# =============================================================================
# TestFormatGroupedConvSummary
# =============================================================================
class TestFormatGroupedConvSummary(unittest.TestCase):
"""Tests for format_grouped_conv_summary."""
def test_returns_non_empty_string(self):
"""Should return a non-empty string."""
config = make_valid_grouped_conv_config()
summary = format_grouped_conv_summary(config)
self.assertIsInstance(summary, str)
self.assertGreater(len(summary), 0)
def test_contains_key_info(self):
"""Summary should contain key config info (variant, arch, layout, dtype)."""
config = make_valid_grouped_conv_config()
summary = format_grouped_conv_summary(config)
# Should mention at least some of: variant, arch, layout, dtype
summary_lower = summary.lower()
has_key_info = (
"2d" in summary_lower
or "fwd" in summary_lower
or "gfx" in summary_lower
or "nhwgc" in summary_lower
or "fp16" in summary_lower
)
self.assertTrue(
has_key_info,
f"Summary should contain key info, got: {summary}",
)
def test_empty_config_returns_something(self):
"""Empty or minimal config should still return a string."""
summary = format_grouped_conv_summary({})
self.assertIsInstance(summary, str)
self.assertGreaterEqual(len(summary), 0)
if __name__ == "__main__":
unittest.main()
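The validate/auto-correct workflow exercised by the tests above can be sketched as follows. This is a hedged, self-contained stand-in, not the dispatcher's actual implementation: `ValidationResult`, `validate`, `auto_correct`, and the `VALID_WAVES` set are hypothetical stand-ins that only mirror the behavior the tests assert (an `is_valid` flag plus an `errors` list, and snapping invalid wave settings to a valid value).

```python
from dataclasses import dataclass, field

@dataclass
class ValidationResult:
    """Minimal stand-in for GroupedConvValidationResult (assumed shape:
    an is_valid flag plus a list of human-readable error strings)."""
    is_valid: bool = True
    errors: list = field(default_factory=list)

VALID_WAVES = (1, 2, 4)  # assumed per-arch valid wave counts, as in the tests above

def validate(config):
    """Flag out-of-range wave settings, mirroring the invalid-wave tests."""
    result = ValidationResult()
    tile = config.get("tile_config", {})
    for key in ("wave_m", "wave_n"):
        if tile.get(key) not in VALID_WAVES:
            result.is_valid = False
            result.errors.append(f"invalid wave setting: {key}={tile.get(key)}")
    return result

def auto_correct(config):
    """Snap invalid wave values to the nearest valid one, then re-validate."""
    corrected = {**config, "tile_config": dict(config.get("tile_config", {}))}
    tile = corrected["tile_config"]
    for key in ("wave_m", "wave_n"):
        if tile.get(key) not in VALID_WAVES:
            tile[key] = min(VALID_WAVES, key=lambda v: abs(v - tile.get(key, 1)))
    return corrected, validate(corrected)

bad = {"tile_config": {"wave_m": 3, "wave_n": 3}}
assert not validate(bad).is_valid
fixed, result = auto_correct(bad)
assert result.is_valid and fixed["tile_config"]["wave_m"] in VALID_WAVES
```

The same two-step contract (validate returns a result object, auto-correct returns a corrected config plus a fresh result) is what the `TestAutoCorrectGroupedConvConfig` cases above rely on.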


@@ -19,7 +19,7 @@ class ProblemDimensionInferenceTest : public ::testing::Test
TEST_F(ProblemDimensionInferenceTest, FromAB_Basic)
{
-    // A: M×K (1024×512), B: K×N (512×2048)
+    // A: MxK (1024x512), B: KxN (512x2048)
auto problem = Problem::from_ab(1024, 512, 512, 2048);
EXPECT_EQ(problem.M, 1024);
@@ -30,7 +30,7 @@ TEST_F(ProblemDimensionInferenceTest, FromAB_Basic)
TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid)
{
-    // A: 1024×512, B: 512×2048, C: 1024×2048
+    // A: 1024x512, B: 512x2048, C: 1024x2048
auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048);
EXPECT_EQ(problem.M, 1024);
@@ -55,7 +55,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC)
TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA)
{
-    // A stored as K×M (transposed)
+    // A stored as KxM (transposed)
TensorShape A{512, 1024, true};
TensorShape B{512, 2048, false};
TensorShape C{1024, 2048, false};
@@ -70,7 +70,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA)
TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB)
{
TensorShape A{1024, 512, false};
-    // B stored as N×K (transposed)
+    // B stored as NxK (transposed)
TensorShape B{2048, 512, true};
TensorShape C{1024, 2048, false};
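The dimension-inference rule that these `ProblemDimensionInferenceTest` hunks exercise can be sketched in a few lines. This is a hedged Python illustration: `infer_mnk` and its tuple-based shapes are hypothetical stand-ins, not the dispatcher's actual C++ API; the transpose convention (A stored KxM, B stored NxK when flagged) is taken from the test comments above.

```python
def infer_mnk(a_shape, a_transposed, b_shape, b_transposed):
    # A is stored MxK normally, KxM when transposed;
    # B is stored KxN normally, NxK when transposed.
    m, k_a = (a_shape[1], a_shape[0]) if a_transposed else a_shape
    k_b, n = (b_shape[1], b_shape[0]) if b_transposed else b_shape
    assert k_a == k_b, "inner (K) dimensions must agree"
    return m, n, k_a

# Mirrors FromShapes_TransposedA: A stored as KxM (512x1024).
assert infer_mnk((512, 1024), True, (512, 2048), False) == (1024, 2048, 512)
# Mirrors FromShapes_TransposedB: B stored as NxK (2048x512).
assert infer_mnk((1024, 512), False, (2048, 512), True) == (1024, 2048, 512)
```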


@@ -187,7 +187,7 @@ int main()
for(const auto& r : results)
{
char size_str[32];
-    snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K);
+    snprintf(size_str, sizeof(size_str), "%4dx%4dx%4d", r.M, r.N, r.K);
printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n",
size_str,


@@ -144,7 +144,7 @@ int main()
all_passed = all_passed && passed;
char size_label[32];
-    snprintf(size_label, sizeof(size_label), "%s %d³", label, M);
+    snprintf(size_label, sizeof(size_label), "%s %d^3", label, M);
printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n",
size_label,