diff --git a/dispatcher/README.md b/dispatcher/README.md index 1395285d60..dc864f7c62 100644 --- a/dispatcher/README.md +++ b/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile Dispatcher -A unified kernel dispatch system for AMD GPUs with C++ and Python frontends. +A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations. **Validated Platform:** AMD Instinct MI300 series (gfx942) @@ -342,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so | `CMAKE_PREFIX_PATH` | - | ROCm installation path | | `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler | -⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. -⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). +**Warning:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. +**Warning:** The current system supports only a single GPU target per build for architecture-based kernel filtering. Do not specify multiple GPU targets at once; if you need several, compile each target in its own build directory. 
--- @@ -363,6 +363,15 @@ cd build/examples ./gemm_04_heuristics # Heuristic kernel selection ./gemm_05_json_export # Registry JSON export ./gemm_06_multi_registry # Multiple registries + +# Grouped Convolution Examples +./grouped_conv_01_basic # Declaration patterns + GPU execution +./grouped_conv_02_all_dirs # Forward/BwdData/BwdWeight with GPU +./grouped_conv_03_bench_val # Benchmark + CPU reference validation +./grouped_conv_04_registry_json # Heuristic selection + JSON export +./grouped_conv_05_bwd_data # Backward data + CPU validation +./grouped_conv_06_bwd_weight # Backward weight + CPU validation +./grouped_conv_07_benchmark # Multi-tile ResNet benchmark ``` ### Python Examples @@ -375,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher # GEMM Examples python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM python3 examples/gemm/python/04_validation.py # CPU reference validation -python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels) +python3 examples/gemm/python/07_stress_test.py # Stress test python3 examples/gemm/python/08_heuristics.py # Heuristic selection + +# Grouped Convolution Examples +python3 examples/grouped_conv/python/01_basic_grouped_conv.py # Config patterns + registry + GPU +python3 examples/grouped_conv/python/02_forward.py # Forward 2D/3D + CPU ref +python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + CPU ref +python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref +python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark +python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON ``` ### Example Output @@ -647,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so") ### Data Flow ``` -KernelConfig → Registry → Dispatcher → GPU Execution +KernelConfig -> Registry -> Dispatcher -> GPU Execution ``` 1. 
**KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) @@ -843,31 +860,49 @@ make -j$(nproc) ``` dispatcher/ -├── README.md # This file -├── CMakeLists.txt # Build configuration -│ -├── include/ck_tile/dispatcher/ # C++ headers -│ ├── dispatcher.hpp # GEMM dispatcher -│ ├── registry.hpp # Kernel registry -│ └── kernel_key.hpp # Kernel configuration -│ -├── src/ # C++ implementation -│ -├── codegen/ # Kernel generation -│ ├── unified_gemm_codegen.py # GEMM kernel generator -│ └── arch_specs.json # GPU specifications -│ -├── bindings/ctypes/ # Python ctypes interface -│ └── gemm_ctypes_lib.cpp # GEMM Python library -│ -├── examples/ # Examples -│ └── gemm/ -│ ├── cpp/ # C++ GEMM examples (01-06) -│ └── python/ # Python GEMM examples (01-11) -│ -├── scripts/ # Build scripts -│ -└── tests/ # Unit tests +|---- README.md # This file +|---- CMakeLists.txt # Build configuration +| +|---- include/ck_tile/dispatcher/ # C++ headers +| |---- dispatcher.hpp # Main dispatcher include +| |---- registry.hpp # GEMM kernel registry +| |---- kernel_key.hpp # Kernel configuration +| |---- grouped_conv_config.hpp # Grouped conv configuration +| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder) +| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations +| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe) +| +---- grouped_conv_utils.hpp # Grouped conv utilities +| +|---- src/ # C++ implementation +| +|---- codegen/ # Kernel generation +| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings +| |---- unified_gemm_codegen.py # GEMM kernel generator +| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator +| +---- arch_specs.json # GPU specifications +| +|---- python/ # Python utilities +| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output +| |---- ctypes_utils.py # GEMM ctypes utilities +| +---- grouped_conv_utils.py # Grouped conv utilities +| 
+|---- scripts/ # Build scripts +| |---- compile_gemm_examples.py # GEMM build script +| +---- compile_grouped_conv_examples.py # Grouped conv build script +| +|---- bindings/ctypes/ # Python ctypes interface +| |---- gemm_ctypes_lib.cpp # GEMM Python library +| +---- conv_ctypes_lib.cpp # Grouped conv Python library +| +|---- examples/ # Examples +| |---- gemm/ +| | |---- cpp/ # C++ GEMM examples (01-07) +| | +---- python/ # Python GEMM examples (01-11) +| +---- grouped_conv/ +| |---- cpp/ # C++ Grouped Conv examples (01-07) +| +---- python/ # Python Grouped Conv examples (01-06) +| ++---- tests/ # Unit tests (C++ and Python) ``` --- @@ -879,17 +914,49 @@ dispatcher/ | GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | | GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | | Codegen | [codegen/README.md](codegen/README.md) | +| Python Utils | [python/README.md](python/README.md) | +| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) | --- -## Archived Content +## Grouped Convolution Support -Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples -- `codegen/unified_conv_codegen.py` - Conv kernel generator -- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers -- `python/conv_utils.py` - Conv Python utilities +Grouped convolution is fully supported alongside GEMM, with shared infrastructure to eliminate duplication. 
+ +### Python + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output-dir build/generated_kernels \ + --datatype fp16 --variant forward --ndim-spatial 2 + +# Build grouped conv examples +python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp +``` + +### Key Files + +| Component | File | +|-----------|------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | +| Shared Codegen | `codegen/codegen_common.py` | +| Shared Utils | `python/dispatcher_common.py` | + +### Variants + +- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution +- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input +- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights + +### Shared Infrastructure + +GEMM and grouped convolution share common code to avoid duplication: +- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion +- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output --- diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md index 7cda21f6ec..04029d32a9 100644 --- a/dispatcher/bindings/README.md +++ b/dispatcher/bindings/README.md @@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher. 
``` bindings/ -├── ctypes/ # Python ctypes bindings (C API) -│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API -│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) -│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API -│ ├── gpu_helper.cpp # CLI helper for Python -│ └── CMakeLists.txt -└── README.md +|---- ctypes/ # Python ctypes bindings (C API) +| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API +| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data) +| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library) +| |---- gpu_helper.cpp # CLI helper for Python +| +---- CMakeLists.txt ++---- README.md ``` ## ctypes Bindings @@ -65,7 +65,7 @@ lib.dispatcher_cleanup() | `dispatcher_export_registry_json()` | Export registry as JSON | | `dispatcher_cleanup()` | Release resources | -### Convolution API +### Grouped Convolution API | Function | Description | |----------|-------------| @@ -105,5 +105,11 @@ Output is JSON for easy parsing: See the examples that use these bindings: - **GEMM**: `dispatcher/examples/gemm/python/` -- **Conv**: `dispatcher/examples/conv/python/` + +### Grouped Convolution + +Grouped convolution C++ headers and Python utilities are in: +- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp` +- **Python Utils**: `dispatcher/python/grouped_conv_utils.py` +- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py` diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt index 804e5e9bd7..18314017f2 100644 --- a/dispatcher/bindings/ctypes/CMakeLists.txt +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -78,7 +78,7 @@ endif() # Look for forward kernels file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") # Look for backward data kernels -file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +file(GLOB 
CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp") # Fallback: any conv kernel (for backwards compatibility) file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") @@ -112,7 +112,7 @@ endif() # Add backward data kernel if available if(CONV_BWDD_KERNEL_HEADERS) list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) - message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) endif() diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp index 09e058f80f..96b4aa3462 100644 --- a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -53,6 +53,7 @@ struct ConvBwdwProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; + int split_k; }; // ============================================================================= @@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr, grad_weight_ptr, // wei_ptr = grad_weight (output) {}, // ds_ptr grad_output_ptr, // out_ptr = grad_output - 1 // k_batch - ); + (prob->split_k > 1) ? prob->split_k : 1); ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10}; diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp index d3c64621a7..002219c82e 100644 --- a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -1,128 +1,46 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT - -/** - * Convolution Dispatcher ctypes Library - * - * Provides C API for Python ctypes integration. - * Supports forward convolution. Backward operations require additional headers. - * - * REQUIRED: Forward kernel header must be force-included via -include flag. - * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE - * - * Usage from Python: - * lib = ctypes.CDLL("libdispatcher_conv.so") - * lib.conv_dispatcher_init() - * lib.conv_dispatcher_run(...) - */ +// +// Multi-kernel grouped convolution dispatcher for Python ctypes. +// +// Supports: forward / backward-data / backward-weight x 2D / 3D +// +// The dispatch header (conv_python_dispatch.hpp) is force-included via +// -include and brings in ALL compiled kernels with these aliases: +// +// 2D launchers (from include_all headers): +// SelectedConvKernelLauncher (forward 2D) +// SelectedConvBwdDataLauncher (backward-data 2D) +// SelectedConvBwdWeightLauncher (backward-weight 2D) +// +// 3D launchers (from dispatch header): +// ConvFwd3dLauncher (forward 3D) +// ConvBwdData3dLauncher (backward-data 3D) +// ConvBwdWeight3dLauncher (backward-weight 3D) +// +// Usage from Python: +// lib = ctypes.CDLL("libdispatcher_conv_lib.so") +// lib.conv_dispatcher_init() +// lib.conv_dispatcher_run(input, weight, output, &problem, stream) #include -#include -#include +#include #include -#include "ck_tile/dispatcher/conv_utils.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -using namespace ck_tile::dispatcher; - -// Global state (using shared_ptr for safe memory management) -static std::shared_ptr g_registry = nullptr; -static std::shared_ptr g_dispatcher = nullptr; -static std::vector g_kernels; - extern "C" { -// ============================================================================= -// Initialization -// ============================================================================= - -int conv_dispatcher_init() +// 
========================================================================= +// Problem definition (matches Python ctypes struct exactly) +// ========================================================================= +enum ConvDirection { - if(g_registry) - return 0; // Already initialized - - g_registry = std::make_shared(); - g_dispatcher = std::make_shared(g_registry.get()); - - // Register kernel configurations using simple ConvKernelSet - // (actual kernel launch uses the force-included SelectedConvKernelLauncher) - using namespace ck_tile::dispatcher::conv_decl; - - // Forward kernels (required - must be force-included) - // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb - ConvKernelSet fwd_set; - fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv4") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(fwd_set, ConvRegistry::Priority::High); - -#ifdef CONV_BWD_DATA_AVAILABLE - // Backward data kernels - // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 - ConvKernelSet bwd_data_set; - bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv3") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); -#endif - - return 0; -} - -int conv_dispatcher_cleanup() -{ - // shared_ptr automatically handles cleanup when reset - g_dispatcher.reset(); - g_registry.reset(); - g_kernels.clear(); - return 0; -} - -// ============================================================================= -// Registry Management -// ============================================================================= - -int 
conv_dispatcher_get_kernel_count() -{ - if(!g_registry) - return 0; - return static_cast(g_registry->size()); -} - -int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) -{ - if(index < 0 || !buffer || buffer_size <= 0) - return -1; - - if(!g_registry) - return -1; - - // Use registry to get kernel names (they are registered with full names) - const auto& kernels = g_registry->all_kernels(); - if(static_cast(index) >= kernels.size()) - return -1; - - const auto* kernel = kernels[index]; - std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); - buffer[buffer_size - 1] = '\0'; - return 0; -} - -// ============================================================================= -// Problem Definition -// ============================================================================= + CONV_FORWARD = 0, + CONV_BWD_DATA = 1, + CONV_BWD_WEIGHT = 2 +}; struct ConvProblemC { @@ -132,267 +50,33 @@ struct ConvProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; - int direction; // 0=forward, 1=bwd_data, 2=bwd_weight + int direction; + int split_k; }; -// ============================================================================= -// Kernel Selection -// ============================================================================= +// ========================================================================= +// Initialization / lifecycle +// ========================================================================= +int conv_dispatcher_init() { return 0; } +int conv_dispatcher_cleanup() { return 0; } -int conv_dispatcher_is_supported(const ConvProblemC* prob) -{ - if(!g_registry || !prob) - return 0; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = 
{prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - return kernel ? 1 : 0; -} - -int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) -{ - if(!g_registry || !prob || !kernel_name || buffer_size <= 0) - return -1; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1; - - std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); - kernel_name[buffer_size - 1] = '\0'; - - return 0; -} - -// ============================================================================= -// Convolution Execution -// ============================================================================= - -// Helper to build ConvParam -static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) -{ - // Determine if this is 2D or 3D convolution - const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); - - if(is_3d) - { - // 3D convolution: use all spatial dimensions - return ck_tile::conv::ConvParam{3, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_z, prob->filter_y, prob->filter_x}, - {prob->input_d, prob->input_h, prob->input_w}, - {prob->stride_d, prob->stride_h, 
prob->stride_w}, - {prob->dilation_d, prob->dilation_h, prob->dilation_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}}; - } - else - { - // 2D convolution: only use H, W dimensions - return ck_tile::conv::ConvParam{2, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_y, prob->filter_x}, - {prob->input_h, prob->input_w}, - {prob->stride_h, prob->stride_w}, - {prob->dilation_h, prob->dilation_w}, - {prob->pad_h, prob->pad_w}, - {prob->pad_h, prob->pad_w}}; - } -} - -// Forward convolution (required - kernel header must be force-included) -static float run_forward(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - // SelectedConvKernelLauncher is defined in the force-included forward kernel header - return SelectedConvKernelLauncher::launch(args, stream_cfg); -} - -#ifdef CONV_BWD_DATA_AVAILABLE -// Backward data convolution (optional) -// Computes: grad_input = conv_bwd_data(weight, grad_output) -// -// Parameters: -// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) -// weight_ptr: W - frozen weights (const, read-only INPUT) -// grad_input_ptr: dX - gradient for input (writable, OUTPUT) -static float run_bwd_data(const void* grad_output_ptr, - const void* weight_ptr, - void* grad_input_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // CK Tile API uses tensor POSITION names (from forward pass), not data flow: - // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) - // wei_ptr = weight tensor = weight_ptr (W, const) - // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data) - ck_tile::GroupedConvBwdDataHostArgs args( 
- conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdDataLauncher::launch(args, stream_cfg); -} -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE -// Backward weight convolution (optional) -// Parameters: -// input_ptr: original forward input X (const, read-only) -// grad_output_ptr: gradient from next layer dY (const, read-only) -// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) -static float run_bwd_weight(const void* input_ptr, - const void* grad_output_ptr, - void* grad_weight_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // GroupedConvBwdWeightHostArgs constructor order: - // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) - // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) - ck_tile::GroupedConvBwdWeightHostArgs args( - conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); -} -#endif - -/** - * @brief Execute convolution based on direction specified in prob - * - * Parameter mapping varies by direction: - * Forward (direction=0): - * input_ptr = X (input tensor) - * weight_ptr = W (weight tensor) - * output_ptr = Y (output buffer) - * - * Backward Data (direction=1): - * input_ptr = dY (grad_output - gradient from next layer) - * weight_ptr = W (weight tensor, frozen) - * output_ptr = dX (grad_input buffer) - * - * Backward Weight (direction=2): - * input_ptr = X (forward input tensor) - * weight_ptr = dY (grad_output - gradient from next layer) - * output_ptr = dW (grad_weight buffer) - */ -float conv_dispatcher_run(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - // Validate all required pointers before kernel launch - 
if(!g_dispatcher || !prob) - return -1.0f; - if(!input_ptr || !weight_ptr || !output_ptr) - return -1.0f; // Null data pointer would cause kernel crash - - // Build problem for kernel selection - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - // Select kernel - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1.0f; - - // Dispatch based on direction - switch(prob->direction) - { - case 0: // Forward (always available) - return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); - -#ifdef CONV_BWD_DATA_AVAILABLE - case 1: // Backward data - // Convention: caller passes (grad_output, weight, grad_input_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. - // run_bwd_data expects: (grad_output, weight, grad_input) - return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE - case 2: // Backward weight - // Convention: caller passes (input, grad_output, grad_weight_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. 
- // run_bwd_weight expects: (input, grad_output, grad_weight) - return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - - default: return -1.0f; - } -} - -// ============================================================================= -// Info -// ============================================================================= - -const char* conv_dispatcher_version() { return "1.0.0"; } +// ========================================================================= +// Library info +// ========================================================================= +const char* conv_dispatcher_version() { return "2.0.0"; } int conv_dispatcher_has_kernels() { - return 1; // Forward kernel is required +#if defined(CONV_FWD_2D_AVAILABLE) || defined(CONV_FWD_3D_AVAILABLE) + return 1; +#else + return 0; +#endif } int conv_dispatcher_has_bwd_data() { -#ifdef CONV_BWD_DATA_AVAILABLE +#if defined(CONV_BWD_DATA_2D_AVAILABLE) || defined(CONV_BWD_DATA_3D_AVAILABLE) return 1; #else return 0; @@ -401,11 +85,240 @@ int conv_dispatcher_has_bwd_data() int conv_dispatcher_has_bwd_weight() { -#ifdef CONV_BWD_WEIGHT_AVAILABLE +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) || defined(CONV_BWD_WEIGHT_3D_AVAILABLE) return 1; #else return 0; #endif } +int conv_dispatcher_get_kernel_count() +{ + return CONV_KERNEL_COUNT; // defined in conv_python_dispatch.hpp +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(!buffer || buffer_size <= 0 || index < 0 || index >= CONV_KERNEL_COUNT) + return -1; + std::strncpy(buffer, CONV_KERNEL_NAMES[index], buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ========================================================================= +// Support query: returns int 1/0 like the has_* queries (C-ABI friendly for ctypes' default int restype) +// ========================================================================= +int conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!prob) + return 0; + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + 
switch(prob->direction) + { + case CONV_FORWARD: +#if defined(CONV_FWD_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_FWD_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_DATA: +#if defined(CONV_BWD_DATA_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_DATA_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_WEIGHT: +#if defined(CONV_BWD_WEIGHT_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + default: return false; + } +} + +// ========================================================================= +// ConvParam builders +// ========================================================================= +static ck_tile::conv::ConvParam make_param_2d(const ConvProblemC* p) +{ + return ck_tile::conv::ConvParam{2, + p->G, + p->N, + p->K, + p->C, + {p->filter_y, p->filter_x}, + {p->input_h, p->input_w}, + {p->stride_h, p->stride_w}, + {p->dilation_h, p->dilation_w}, + {p->pad_h, p->pad_w}, + {p->pad_h, p->pad_w}}; +} + +static ck_tile::conv::ConvParam make_param_3d(const ConvProblemC* p) +{ + return ck_tile::conv::ConvParam{3, + p->G, + p->N, + p->K, + p->C, + {p->filter_z, p->filter_y, p->filter_x}, + {p->input_d, p->input_h, p->input_w}, + {p->stride_d, p->stride_h, p->stride_w}, + {p->dilation_d, p->dilation_h, p->dilation_w}, + {p->pad_d, p->pad_h, p->pad_w}, + {p->pad_d, p->pad_h, p->pad_w}}; +} + +// ========================================================================= +// Kernel launch helpers +// ========================================================================= + +#ifdef CONV_FWD_2D_AVAILABLE +static float +launch_fwd_2d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 
3, 10}; + return SelectedConvKernelLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_FWD_3D_AVAILABLE +static float +launch_fwd_3d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvFwd3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_2D_AVAILABLE +static float launch_bwd_data_2d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdDataLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_3D_AVAILABLE +static float launch_bwd_data_3d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdData3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE +static float launch_bwd_weight_2d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + const int k_batch = (p->split_k > 1) ? p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdWeightLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE +static float launch_bwd_weight_3d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + const int k_batch = (p->split_k > 1) ? 
p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdWeight3dLauncher::launch(args, sc); +} +#endif + +// ========================================================================= +// Main dispatch. Returns the launcher's result, or -1 (null argument / unknown direction), -2 (no kernel compiled for this direction and dimensionality), -3 (an exception was thrown) +// +// direction=0 (forward): a=X(input), b=W(weight), c=Y(output) +// direction=1 (bwd_data): a=dY(grad_out), b=W(weight), c=dX(grad_in) +// direction=2 (bwd_weight): a=X(input), b=dY(grad_out), c=dW(grad_wei) +// ========================================================================= +float conv_dispatcher_run( + const void* a_ptr, const void* b_ptr, void* c_ptr, const ConvProblemC* prob, void* stream) +{ + if(!prob || !a_ptr || !b_ptr || !c_ptr) + return -1.0f; + + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + hipStream_t hip_stream = static_cast<hipStream_t>(stream); + + try + { + switch(prob->direction) + { + case CONV_FORWARD: +#ifdef CONV_FWD_3D_AVAILABLE + if(is_3d) + return launch_fwd_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_FWD_2D_AVAILABLE + if(!is_3d) + return launch_fwd_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_DATA: +#ifdef CONV_BWD_DATA_3D_AVAILABLE + if(is_3d) + return launch_bwd_data_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_DATA_2D_AVAILABLE + if(!is_3d) + return launch_bwd_data_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_WEIGHT: +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE + if(is_3d) + return launch_bwd_weight_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE + if(!is_3d) + return launch_bwd_weight_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + default: return -1.0f; + } + } + catch(const std::exception&) + { + return -3.0f; // Kernel rejected args (e.g. unsupported tile/channel combo) + } + catch(...) 
+ { + return -3.0f; + } +} + } // extern "C" diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md index 0bd2966a85..664b59b6b1 100644 --- a/dispatcher/codegen/ADDING_NEW_GPU.md +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: ``` -arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) - → arch_specs_generated.hpp (C++) +arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python) + -> arch_specs_generated.hpp (C++) ``` ## Quick Start @@ -175,14 +175,14 @@ for error in result.errors: ``` codegen/ -├── arch_specs.json # Single source of truth (EDIT THIS) -├── generate_arch_specs.py # Generator script -├── arch_specs_generated.py # Generated Python module -└── ADDING_NEW_GPU.md # This file +|---- arch_specs.json # Single source of truth (EDIT THIS) +|---- generate_arch_specs.py # Generator script +|---- arch_specs_generated.py # Generated Python module ++---- ADDING_NEW_GPU.md # This file include/ck_tile/dispatcher/ -├── arch_specs_generated.hpp # Generated C++ header -└── arch_filter.hpp # C++ filter +|---- arch_specs_generated.hpp # Generated C++ header ++---- arch_filter.hpp # C++ filter ``` ## Best Practices diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md index 2d753924f5..40a9b7b8c1 100644 --- a/dispatcher/codegen/README.md +++ b/dispatcher/codegen/README.md @@ -1,11 +1,22 @@ -# CK Tile GEMM Unified Code Generator +# CK Tile Unified Code Generators -Single source of truth for all GEMM kernel generation. +Single source of truth for GEMM and Grouped Convolution kernel generation. > **See also:** [Main Dispatcher README](../README.md) for installation and core concepts. 
+## Shared Infrastructure + +Both GEMM and Grouped Conv generators share common code via `codegen_common.py`: +- `TileConfig` - Dataclass for tile dimensions +- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation +- `CommonTypeMappings` - Dtype-to-C++ type mappings +- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging +- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.) + ## Quick Start +### GEMM + ```bash cd dispatcher/codegen @@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \ --variants standard preshuffle multi_d ``` +### Grouped Convolution + +```bash +cd dispatcher/codegen + +# Generate forward FP16 grouped conv kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --datatype fp16 \ + --variant forward \ + --ndim-spatial 2 + +# Generate backward data kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --variant backward_data \ + --ndim-spatial 2 +``` + ## Using from Python ```python @@ -58,13 +88,13 @@ results = codegen.generate_all() ## Variants ### Standard -Basic GEMM: `C = A × B` +Basic GEMM: `C = A x B` ### PreShuffle Optimized weight access with LDS pre-shuffling. Best for large matrices. ### Multi-D -Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` +Element-wise fusion: `C = op(A x B + D0 + D1 + ...)` Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` @@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` ``` generated_kernels/ -├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp -├── gemm_fp16_rcr_compv4_..._preshuffle.hpp -├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp -└── ... 
@dataclass
class TileConfig:
    """Tile configuration parameters shared by GEMM and grouped conv.

    Three triples, one value per M/N/K dimension:
      * ``tile_*``      - block tile size
      * ``warp_*``      - warp distribution inside the block tile
      * ``warp_tile_*`` - tile size handled by a single warp
    """

    tile_m: int
    tile_n: int
    tile_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    def is_valid(self) -> bool:
        """Return True when the block tile is evenly covered by the warps.

        Every field must be strictly positive, and each block-tile extent
        must be an exact multiple of ``warp_x * warp_tile_x``.

        Non-positive warp counts / warp-tile sizes are rejected up front:
        a zero would otherwise raise ``ZeroDivisionError`` in the modulo
        below, and a negative factor could pass the divisibility test
        (e.g. ``128 % -64 == 0``) even though it describes no real tiling.
        """
        block = (self.tile_m, self.tile_n, self.tile_k)
        warps = (self.warp_m, self.warp_n, self.warp_k)
        warp_tiles = (self.warp_tile_m, self.warp_tile_n, self.warp_tile_k)

        if any(value <= 0 for value in (*block, *warps, *warp_tiles)):
            return False

        return all(
            extent % (count * size) == 0
            for extent, count, size in zip(block, warps, warp_tiles)
        )
@dataclass
class TraitConfigBase:
    """Kernel trait settings common to the GEMM and grouped-conv generators.

    Operation-specific subclasses add their own knobs: GEMM adds
    ``persistent``; grouped conv adds ``double_smem_buffer`` and
    ``num_groups_to_merge``.
    """

    pipeline: str  # mem, compv3, compv4, compv5, ...
    epilogue: str  # cshuffle, default
    scheduler: str  # intrawave, interwave
    pad_m: bool
    pad_n: bool
    pad_k: bool

    # Rejected (pipeline, epilogue, scheduler) triples. Every pipeline named
    # here is intrawave-only, so combining it with the interwave scheduler is
    # invalid regardless of the epilogue. Pipelines that are not listed
    # (e.g. 'mem') are never rejected by this table.
    _UNSUPPORTED: ClassVar[FrozenSet] = frozenset(
        (pipeline_name, epilogue_name, "interwave")
        for pipeline_name in (
            "compv3",
            "compv4",
            "compv5",
            "compv6",
            "comp_async",
            "basic_async_v1",
        )
        for epilogue_name in ("cshuffle", "default")
    )

    def is_valid(self) -> bool:
        """Return False only for combinations listed in ``_UNSUPPORTED``."""
        combination = (self.pipeline, self.epilogue, self.scheduler)
        return combination not in self._UNSUPPORTED


# ============================================================================
# Type Mappings (centralized for both GEMM and conv codegen)
# ============================================================================


class CommonTypeMappings:
    """Lookup tables from config-file spellings to C++ identifiers.

    Used by both the GEMM and the grouped-conv generators.
    """

    # Unqualified CK Tile element types (valid inside `using namespace ck_tile`).
    DTYPE_TO_CK = {
        "fp16": "fp16_t",
        "bf16": "bf16_t",
        "fp32": "float",
        "fp8": "fp8_t",
        "bf8": "bf8_t",
        "int8": "int8_t",
    }

    # Same types spelled so they resolve outside the ck_tile namespace:
    # `float` and `int8_t` stay bare, everything else gets a ck_tile:: prefix.
    DTYPE_TO_CK_QUALIFIED = {
        dtype: (
            ck_type if ck_type in ("float", "int8_t") else f"ck_tile::{ck_type}"
        )
        for dtype, ck_type in DTYPE_TO_CK.items()
    }

    # Dispatcher-side DataType enumerators.
    DTYPE_TO_DISPATCHER = {
        "fp16": "DataType::FP16",
        "bf16": "DataType::BF16",
        "fp32": "DataType::FP32",
        "fp8": "DataType::FP8",
        "bf8": "DataType::BF8",
        "int8": "DataType::INT8",
    }

    # GEMM layouts only: "r" = row major, "c" = column major.
    # Convolution layouts (NHWGC, GKYXC, ...) live in
    # unified_grouped_conv_codegen.py (GroupedConvLayout /
    # GroupedConvTypeMappings).
    GEMM_LAYOUT_TO_CK = {
        "r": "tensor_layout::gemm::RowMajor",
        "c": "tensor_layout::gemm::ColumnMajor",
    }
    LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK  # backward compat alias

    GEMM_LAYOUT_TO_DISPATCHER = {
        "r": "LayoutTag::RowMajor",
        "c": "LayoutTag::ColMajor",
    }
    LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER  # backward compat alias

    # GEMM-only pipeline tables (consumed by unified_gemm_codegen.py).
    # Convolution pipelines are mapped by GroupedConvTypeMappings in
    # unified_grouped_conv_codegen.py. CK Tile conv supports:
    #   BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4.
    # The dispatcher currently generates: mem, compv3, compv4.
    # preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv).
    PIPELINE_TO_CK = {
        "mem": "GemmPipelineAgBgCrMem",
        "compv3": "GemmPipelineAgBgCrCompV3",
        "compv4": "GemmPipelineAgBgCrCompV4",
        "compv5": "GemmPipelineAgBgCrCompV5",
        "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
    }

    # Each base-pipeline class is the pipeline class name with a "Base" prefix.
    PIPELINE_TO_BASE = {
        pipeline: "Base" + ck_class for pipeline, ck_class in PIPELINE_TO_CK.items()
    }

    PIPELINE_TO_DISPATCHER = {
        "mem": "Pipeline::Mem",
        "compv3": "Pipeline::CompV3",
        "compv4": "Pipeline::CompV4",
        "compv5": "Pipeline::CompV5",
        "preshufflev2": "Pipeline::PreShuffleV2",
    }

    SCHEDULER_TO_CK = {
        "intrawave": "GemmPipelineScheduler::Intrawave",
        "interwave": "GemmPipelineScheduler::Interwave",
        "default": "GemmPipelineScheduler::Default",
    }

    SCHEDULER_TO_DISPATCHER = {
        "intrawave": "Scheduler::Intrawave",
        "interwave": "Scheduler::Interwave",
        "default": "Scheduler::Auto",
    }

    EPILOGUE_TO_DISPATCHER = {
        "cshuffle": "Epilogue::CShuffle",
        "default": "Epilogue::Default",
    }

    @staticmethod
    def get_output_dtype(dtype: str) -> str:
        """Return the output dtype for an input dtype (fp8/bf8 -> fp16)."""
        if dtype in ("fp8", "bf8"):
            return "fp16"
        return dtype
def generate_cpp_compilation_unit(kernel_name: str) -> str:
    """Return the text of a .cpp file that only includes ``<kernel_name>.hpp``.

    One such translation unit is emitted per kernel so that including the
    generated header triggers the template instantiation.
    """
    source_lines = (
        f"// Auto-generated compilation unit for {kernel_name}",
        f'#include "{kernel_name}.hpp"',
    )
    return "".join(f"{line}\n" for line in source_lines)


def parallel_generate(
    generate_fn: Callable[[T], R],
    items: Sequence[T],
    parallel: bool = True,
) -> List[R]:
    """Apply ``generate_fn`` to every element of ``items``.

    Args:
        generate_fn: Callable invoked once per item.
        items: Work items; an empty sequence yields an empty list.
        parallel: When True and there is more than one item, the calls are
            submitted to a ``ThreadPoolExecutor``; otherwise they run one
            after another in the calling thread.

    Returns:
        One result per item. In the threaded path results are appended as
        futures complete, so their order is not the input order; the serial
        path keeps input order.

    Raises:
        Whatever ``generate_fn`` raises (re-raised by ``Future.result()`` in
        the threaded path).
    """
    outputs: List[R] = []
    if not items:
        return outputs

    use_threads = parallel and len(items) > 1
    if not use_threads:
        for entry in items:
            outputs.append(generate_fn(entry))
            log.info("Generated: %s", entry)
        return outputs

    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Map each future back to its item so progress can be logged by item.
        pending = {pool.submit(generate_fn, entry): entry for entry in items}
        for finished in concurrent.futures.as_completed(pending):
            outputs.append(finished.result())
            log.info("Generated: %s", pending[finished])

    return outputs
+ +_arch_data_cache: Optional[Dict] = None + + +def _get_arch_data() -> Dict: + """Load arch filter data, with caching.""" + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv4", "cshuffle", "interwave"), + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + return _arch_data_cache + + +def valid_wave_configs(arch: str) -> List[List[int]]: + """Return valid [wave_m, wave_n, wave_k] combos for *arch*.""" + data = _get_arch_data() + return data["warp_combos"].get(arch, [[2, 2, 1]]) + + +def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]: + """Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*. + + The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is + fp32 for float types and int32 for int8. + """ + data = _get_arch_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + arch_tiles = data["warp_tile_combos"].get(arch, {}) + return arch_tiles.get(dtype_key, [[32, 32, 16]]) + + +def valid_trait_configs() -> List[Tuple[str, str]]: + """Return valid (pipeline, scheduler) pairs. + + Compute pipelines only support intrawave; mem supports both. 
def valid_trait_configs() -> List[Tuple[str, str]]:
    """Return valid (pipeline, scheduler) pairs.

    The compute pipelines are paired with intrawave only; ``mem`` is paired
    with both intrawave and interwave.
    """
    pairs = [(pipeline, "intrawave") for pipeline in ("compv3", "compv4", "compv5")]
    pairs.extend(("mem", scheduler) for scheduler in ("intrawave", "interwave"))
    return pairs


def needs_wave_expansion(config: dict) -> bool:
    """True if wave_m or wave_n is a wildcard (ANY_INT = -1).

    A missing key is treated as 2, i.e. as not being a wildcard.
    """
    return any(config.get(key, 2) == ANY_INT for key in ("wave_m", "wave_n"))


def needs_warp_expansion(config: dict) -> bool:
    """True if warp_m or warp_n is a wildcard (ANY_INT = -1).

    A missing key is treated as 32, i.e. as not being a wildcard.
    """
    return any(config.get(key, 32) == ANY_INT for key in ("warp_m", "warp_n"))


def needs_pipeline_expansion(config: dict) -> bool:
    """True if pipeline is a wildcard (\"*\"); a missing key means ``compv4``."""
    selected = config.get("pipeline", "compv4")
    return selected == "*"
a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py index 53a9bff3ed..e11bd7a0a5 100644 --- a/dispatcher/codegen/generate_kernel_wrappers.py +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -17,10 +17,10 @@ Usage: Output structure: build/kernel_wrappers/ - ├── gemm_fp16_rcr_128x128x32.cpp - ├── gemm_fp16_rcr_256x256x64.cpp - ├── conv_fwd_fp16_2d_128x128.cpp - └── ... + |---- gemm_fp16_rcr_128x128x32.cpp + |---- gemm_fp16_rcr_256x256x64.cpp + |---- conv_fwd_fp16_2d_128x128.cpp + +---- ... Each .cpp simply includes its corresponding .hpp and forces symbol emission. """ diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py index 537fc40581..980b4e5fd0 100644 --- a/dispatcher/codegen/kernel_config_loader.py +++ b/dispatcher/codegen/kernel_config_loader.py @@ -359,8 +359,8 @@ class ConvTraitConfig: @dataclass -class ConvKernelConfig: - """Complete convolution kernel configuration""" +class GroupedConvKernelConfig: + """Complete grouped convolution kernel configuration""" tile: ConvTileConfig = field(default_factory=ConvTileConfig) trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) @@ -419,7 +419,11 @@ class ConvKernelConfig: def kernel_name(self) -> str: """Generate kernel name from config""" - variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + variant_map = { + "forward": "fwd", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + } var_str = variant_map.get(self.variant, self.variant) name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" @@ -433,11 +437,11 @@ class ConvKernelConfig: @dataclass -class ConvKernelConfigSet: +class GroupedConvKernelConfigSet: """A set of convolution kernel configurations loaded from JSON""" name: str = "default" - configs: List[ConvKernelConfig] = field(default_factory=list) + configs: List[GroupedConvKernelConfig] = field(default_factory=list) # Tile parameter ranges tile_m_values: List[int] = 
field(default_factory=lambda: [128]) @@ -481,7 +485,7 @@ class ConvKernelConfigSet: layout: str = "nhwgc" gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) - def generate_configs(self) -> Iterator[ConvKernelConfig]: + def generate_configs(self) -> Iterator[GroupedConvKernelConfig]: """Generate all kernel configurations (cartesian product)""" # Tile parameters tile_params = itertools.product( @@ -548,7 +552,7 @@ class ConvKernelConfigSet: double_smem_buffer=trait[6], num_groups_to_merge=trait[7], ) - yield ConvKernelConfig( + yield GroupedConvKernelConfig( tile=tile_cfg, trait=trait_cfg, dtype_input=self.dtype_input, @@ -599,7 +603,9 @@ class ConvKernelConfigSet: return tile_count * trait_count * extra_count * len(self.gpu_targets) -def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: +def load_grouped_conv_kernel_configs( + json_path: str | Path, +) -> GroupedConvKernelConfigSet: """ Load convolution kernel configurations from a JSON file. @@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: json_path: Path to JSON configuration file Returns: - ConvKernelConfigSet with all parameter values loaded + GroupedConvKernelConfigSet with all parameter values loaded """ json_path = Path(json_path) with open(json_path) as f: data = json.load(f) - config_set = ConvKernelConfigSet() + config_set = GroupedConvKernelConfigSet() # Name config_set.name = data.get("kernel_set_name", json_path.stem) @@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: def generate_cpp_conv_kernel_set_declaration( - config_set: ConvKernelConfigSet, + config_set: GroupedConvKernelConfigSet, set_name: Optional[str] = None, ) -> str: """ - Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. + Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet. 
""" name = set_name or config_set.name - lines = [f"DECL_CONV_KERNEL_SET({name},"] + lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"] for config in config_set.generate_configs(): line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py index b0dd961be7..a818cec83e 100755 --- a/dispatcher/codegen/unified_gemm_codegen.py +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -7,7 +7,7 @@ Unified GEMM Code Generator - Single Source of Truth This is THE unified code generator for all GEMM kernel variants: -- Standard GEMM (C = A × B) +- Standard GEMM (C = A x B) - Preshuffle GEMM (optimized weight access) - Multi-D GEMM (element-wise fusion) @@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict from enum import Enum import concurrent.futures +from codegen_common import ( + TileConfig, + TraitConfigBase, + CommonTypeMappings as TypeMappings, +) + # Import architecture filter for GPU-specific validation try: from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType @@ -194,62 +200,14 @@ class GemmVariant(Enum): MULTI_D = "multi_d" -@dataclass -class TileConfig: - """Tile configuration parameters""" - - tile_m: int - tile_n: int - tile_k: int - warp_m: int - warp_n: int - warp_k: int - warp_tile_m: int - warp_tile_n: int - warp_tile_k: int - - def is_valid(self) -> bool: - """Validate tile configuration""" - return ( - self.tile_m % (self.warp_m * self.warp_tile_m) == 0 - and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 - and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 - and self.tile_m > 0 - and self.tile_n > 0 - and self.tile_k > 0 - ) +# TileConfig imported from codegen_common @dataclass -class TraitConfig: - """Kernel trait configuration""" +class TraitConfig(TraitConfigBase): + """GEMM-specific trait configuration extending TraitConfigBase with persistent mode.""" - pipeline: str # mem, compv3, compv4 - 
epilogue: str # default, cshuffle - scheduler: str # intrawave, interwave - pad_m: bool - pad_n: bool - pad_k: bool - persistent: bool - - def is_valid(self) -> bool: - """Check if trait combination is valid""" - # Unsupported combinations - # Only 'mem' pipeline supports interwave scheduler. - # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. - unsupported = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), - ("compv5", "cshuffle", "interwave"), - ("compv5", "default", "interwave"), - ("compv6", "cshuffle", "interwave"), - ("compv6", "default", "interwave"), - ("comp_async", "cshuffle", "interwave"), - ("comp_async", "default", "interwave"), - } - return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + persistent: bool = False @dataclass @@ -345,89 +303,7 @@ class KernelConfig: # ============================================================================ -class TypeMappings: - """Centralized type mappings for code generation""" - - DTYPE_TO_CK = { - "fp16": "fp16_t", - "bf16": "bf16_t", - "fp32": "float", - "fp8": "fp8_t", - "bf8": "bf8_t", - "int8": "int8_t", - } - - # Fully-qualified types for use outside of 'using namespace ck_tile' scope - DTYPE_TO_CK_QUALIFIED = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", # Built-in type, no namespace - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int8": "int8_t", # Built-in type - } - - DTYPE_TO_DISPATCHER = { - "fp16": "DataType::FP16", - "bf16": "DataType::BF16", - "fp32": "DataType::FP32", - "fp8": "DataType::FP8", - "bf8": "DataType::BF8", - "int8": "DataType::INT8", - } - - LAYOUT_TO_CK = { - "r": "tensor_layout::gemm::RowMajor", - "c": "tensor_layout::gemm::ColumnMajor", - } - - LAYOUT_TO_DISPATCHER = { - "r": "LayoutTag::RowMajor", - "c": "LayoutTag::ColMajor", - } - - PIPELINE_TO_CK = { - "mem": "GemmPipelineAgBgCrMem", - 
"compv3": "GemmPipelineAgBgCrCompV3", - "compv4": "GemmPipelineAgBgCrCompV4", - "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_BASE = { - "mem": "BaseGemmPipelineAgBgCrMem", - "compv3": "BaseGemmPipelineAgBgCrCompV3", - "compv4": "BaseGemmPipelineAgBgCrCompV4", - "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_DISPATCHER = { - "mem": "Pipeline::Mem", - "compv3": "Pipeline::CompV3", - "compv4": "Pipeline::CompV4", - "preshufflev2": "Pipeline::PreShuffleV2", - } - - SCHEDULER_TO_CK = { - "intrawave": "GemmPipelineScheduler::Intrawave", - "interwave": "GemmPipelineScheduler::Interwave", - "default": "GemmPipelineScheduler::Default", - } - - SCHEDULER_TO_DISPATCHER = { - "intrawave": "Scheduler::Intrawave", - "interwave": "Scheduler::Interwave", - "default": "Scheduler::Auto", - } - - EPILOGUE_TO_DISPATCHER = { - "cshuffle": "Epilogue::CShuffle", - "default": "Epilogue::Default", - } - - @staticmethod - def get_output_dtype(dtype: str) -> str: - """Get output datatype (fp8/bf8 -> fp16)""" - return "fp16" if dtype in ["fp8", "bf8"] else dtype +# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias # ============================================================================ @@ -1068,7 +944,11 @@ class UnifiedGemmCodegen: } def generate_all(self, parallel: bool = True) -> Dict: - """Generate all kernels""" + """Generate all kernels. + + When parallel=True, all configs across all variants are collected first, + then generated concurrently in a single thread pool for maximum throughput. 
+ """ log.info("Generating GEMM kernels:") log.info(f" Datatype: {self.datatype}") log.info(f" Layout: {self.layout}") @@ -1078,49 +958,24 @@ class UnifiedGemmCodegen: results = {"kernels": [], "wrappers": [], "failed": []} - # Get configurations + # Collect ALL configs across all variants/preselected sets upfront + all_configs = [] if self.use_preselected: - configs = self._get_preselected_configs() - log.info(f" Total configurations: {len(configs)}") + all_configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(all_configs)}") else: for variant in self.variants: - log.info(f"\nGenerating {variant.value} kernels...") configs = self._get_configs_for_variant(variant) - log.info(f" Configurations: {len(configs)}") + log.info(f" {variant.value}: {len(configs)} configurations") + all_configs.extend(configs) + log.info(f" Total across all variants: {len(all_configs)}") - if parallel: - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self._generate_one, cfg) for cfg in configs - ] - for future in concurrent.futures.as_completed(futures): - try: - k, w = future.result() - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - else: - for cfg in configs: - try: - k, w = self._generate_one(cfg) - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - - # Generate registration header - if results["wrappers"]: - self._generate_registration_header(results["wrappers"]) - - return results - - # Generate from preselected set - if parallel: + # Generate all configs in a single parallel pass + if parallel and all_configs: with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(self._generate_one, cfg) for cfg in configs] + futures = [ + executor.submit(self._generate_one, cfg) for cfg in 
all_configs + ] for future in concurrent.futures.as_completed(futures): try: k, w = future.result() @@ -1130,7 +985,7 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") else: - for cfg in configs: + for cfg in all_configs: try: k, w = self._generate_one(cfg) results["kernels"].append(k) @@ -1139,7 +994,6 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") - # Generate registration header if results["wrappers"]: self._generate_registration_header(results["wrappers"]) @@ -1638,12 +1492,19 @@ def main(): # Write to temp file and use as config import tempfile + import os as _os - with tempfile.NamedTemporaryFile( + _tmp_config = tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False - ) as f: - json.dump(full_config, f) - args.config = Path(f.name) + ) + try: + json.dump(full_config, _tmp_config) + _tmp_config.close() + args.config = Path(_tmp_config.name) + except Exception: + _tmp_config.close() + _os.unlink(_tmp_config.name) + raise except json.JSONDecodeError as e: logging.error(f"Invalid tile-config-json: {e}") return 1 @@ -1672,7 +1533,7 @@ def main(): results = codegen.generate_all(parallel=not args.no_parallel) - logging.info("\n✅ Generation complete!") + logging.info("\nGeneration complete.") logging.info(f" Kernels: {len(results['kernels'])}") logging.info(f" Wrappers: {len(results['wrappers'])}") logging.info(f" Failed: {len(results['failed'])}") @@ -1684,7 +1545,7 @@ def main(): # Generate dispatcher registration if requested if args.register: - logging.info("\n📝 Generating dispatcher registration code...") + logging.info("\nGenerating dispatcher registration code...") try: from generate_dispatcher_registration import ( scan_generated_headers, @@ -1701,11 +1562,20 @@ def main(): ) generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") - logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + logging.info(f"Generated registration 
code for {len(kernels)} kernels") except Exception as e: logging.error(f"Failed to generate registration code: {e}") return 1 + # Clean up temp config file if we created one + if args.tile_config_json and args.config and args.config.exists(): + try: + import os as _os + + _os.unlink(args.config) + except OSError: + pass + return 0 if not results["failed"] else 1 diff --git a/dispatcher/codegen/unified_grouped_conv_codegen.py b/dispatcher/codegen/unified_grouped_conv_codegen.py new file mode 100644 index 0000000000..ff40cb4ed4 --- /dev/null +++ b/dispatcher/codegen/unified_grouped_conv_codegen.py @@ -0,0 +1,1757 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified Grouped Convolution Code Generator + +This is the unified code generator for all grouped convolution kernel variants: +- Forward grouped convolution +- Backward data grouped convolution +- Backward weight grouped convolution + +Generates both CK Tile kernels AND dispatcher wrappers. +Based on the GEMM codegen pattern. 
class GroupedConvVariant(Enum):
    """Grouped convolution kernel variants.

    The enum values are the short spellings ("forward", "bwd_data",
    "bwd_weight").
    """

    FORWARD = "forward"
    BACKWARD_DATA = "bwd_data"
    BACKWARD_WEIGHT = "bwd_weight"


class GroupedConvLayout(Enum):
    """Grouped convolution data layouts, grouped by spatial rank (1D/2D/3D).

    Each member's value is its own name. The per-member comments spell the
    layout out dimension by dimension.
    NOTE(review): the letters presumably follow the usual CK convention
    (N batch, G group, C/K input/output channels, D/H/W tensor extents,
    Z/Y/X filter extents) -- confirm against the CK Tile layout tags.
    """

    # 1D
    NWGC = "NWGC"  # Input/Output: N W G C
    GKXC = "GKXC"  # Weight: G K X C
    NWGK = "NWGK"  # Output: N W G K

    # 2D
    NHWGC = "NHWGC"  # Input: N H W G C
    GKYXC = "GKYXC"  # Weight: G K Y X C
    NHWGK = "NHWGK"  # Output: N H W G K

    # 3D
    NDHWGC = "NDHWGC"  # Input: N D H W G C
    GKZYXC = "GKZYXC"  # Weight: G K Z Y X C
    NDHWGK = "NDHWGK"  # Output: N D H W G K
@dataclass
class GroupedConvTraitConfig(TraitConfigBase):
    """Kernel trait configuration for grouped convolution (extends TraitConfigBase).

    Conv-specific extensions beyond TraitConfigBase. These map to
    GroupedConvTraits template parameters in grouped_convolution_utils.hpp:
    - double_smem_buffer: ping-pong LDS for compute V4+ pipelines
    - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge)
    - split_image: split spatial dims for large tensors (EnableSplitImage)
    - explicit_gemm: use explicit GEMM path (ExplicitGemm)
    - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert

    Every extension has a default, so an instance can be built from the six
    TraitConfigBase fields alone.

    Note: CK Tile already uses long_index_t (64-bit) for group strides and
    batch offsets, so there is no separate "large_tensor" flag. For large
    spatial dimensions, use split_image=True instead.
    """

    # Defaults leave every conv-specific feature switched off.
    double_smem_buffer: bool = False
    num_groups_to_merge: int = 1  # 1 means no group merging
    split_image: bool = False
    explicit_gemm: bool = False
    two_stage: bool = False


# Backward compatibility alias: within this module, TraitConfig names the
# grouped-conv trait class.
TraitConfig = GroupedConvTraitConfig
+ vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + vector_sizes: Optional[Tuple[int, int, int]] = None + + # Occupancy parameters + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Double buffering + double_smem_buffer: bool = False + + def __post_init__(self): + if self.vector_sizes is not None: + self.vector_size_a, self.vector_size_b, self.vector_size_c = ( + self.vector_sizes[:3] + ) + # Sync trait fields with top-level fields (trait is source of truth + # when both are specified, but top-level overrides default trait values). + if self.double_smem_buffer and not self.trait.double_smem_buffer: + self.trait.double_smem_buffer = self.double_smem_buffer + elif self.trait.double_smem_buffer: + self.double_smem_buffer = self.trait.double_smem_buffer + if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1: + self.trait.num_groups_to_merge = self.num_groups_to_merge + elif self.trait.num_groups_to_merge != 1: + self.num_groups_to_merge = self.trait.num_groups_to_merge + + def _layout_str(self) -> str: + """Get layout as lowercase string for naming.""" + if hasattr(self.layout, "value"): + return self.layout.value.lower() + return str(self.layout).lower() + + def name(self, datatype: str) -> str: + """ + Generate kernel name that uniquely identifies the kernel configuration. + + Format: grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler} + _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k} + _{warp_tile_m}x{warp_tile_n}x{warp_tile_k} + [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}] + + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Variant (fwd/bwd_data/bwd_weight) + - Data type + - Layout (nhwgc, nchw, ndhwgc, etc.) 
+ - Spatial dimensions (2d/3d) + - Pipeline, epilogue, scheduler + - Tile, warp, warp_tile dimensions + - Vector sizes, occupancy hints (if non-default) + - Double SMEM buffer, padding flags + """ + t = self.tile + tr = self.trait + layout_str = self._layout_str() + + variant_str = { + GroupedConvVariant.FORWARD: "fwd", + GroupedConvVariant.BACKWARD_DATA: "bwd_data", + GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight", + }[self.variant] + + # Core identity: variant, dtype, layout, dims + name = ( + f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d" + ) + + # Pipeline configuration + name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + + # Block tile dimensions (M_Tile x N_Tile x K_Tile) + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + + # Wave distribution (M_Warp x N_Warp x K_Warp) + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + + # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile) + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Vector sizes (only if non-default) + if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8): + name += ( + f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}" + ) + + # Occupancy hints (only if non-default) + if self.block_per_cu != 1: + name += f"_bpc{self.block_per_cu}" + + if self.num_wave_groups != 1: + name += f"_wg{self.num_wave_groups}" + + if self.num_groups_to_merge != 1: + name += f"_gm{self.num_groups_to_merge}" + + # Double SMEM buffer (for compute V4+) + if self.double_smem_buffer or tr.double_smem_buffer: + name += "_dsb" + + # Two-stage bwd_weight (fp32 workspace + elementwise convert) + if tr.two_stage: + name += "_2stage" + + # Padding suffix (only if not all enabled) + if not (tr.pad_m and tr.pad_n and tr.pad_k): + name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}" + + return name + + def is_valid_for_arch(self, arch: Optional[str] = None) -> bool: + """Check if configuration is valid for target architecture""" + 
target_arch = arch if arch is not None else self.arch + + # Check trait validity + if not self.trait.is_valid(): + return False + + # Backward operations have stricter pipeline requirements: + # - Backward weight: compv4/compv5 have transpose_tile2d issues + # - Backward data: compv4 has get_length issues in bwd_data kernel + # Both backward operations ONLY support compv3 and mem pipelines + if self.variant in ( + GroupedConvVariant.BACKWARD_WEIGHT, + GroupedConvVariant.BACKWARD_DATA, + ): + if self.trait.pipeline not in ("compv3", "mem"): + return False + + # Check warp configuration (from arch_specs) + try: + from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS + + supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch) + if supported is None: + return False # Unknown architecture + warp_cfg = [self.tile.warp_m, self.tile.warp_n, self.tile.warp_k] + if warp_cfg not in supported: + return False + except ImportError: + pass # Allow if arch_specs not available + + return True + + +# ============================================================================ +# Type Mappings +# ============================================================================ + + +class GroupedConvTypeMappings: + """Centralized type mappings for grouped convolution code generation""" + + DTYPE_TO_CK = { + "fp16": "half_t", + "bf16": "bf16_t", + "fp32": "float", + } + + # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits). + # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy; + # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy. 
+ PIPELINE_TO_CK = { + "basic_v1": "GemmPipeline::BASIC_V1", + "mem": "GemmPipeline::MEMORY", + "compv3": "GemmPipeline::COMPUTE_V3", + "compv4": "GemmPipeline::COMPUTE_V4", + "compv5": "GemmPipeline::COMPUTE_V5", + "compv6": "GemmPipeline::COMPUTE_V6", + "comp_async": "GemmPipeline::COMPUTE_ASYNC", + "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + } + + LAYOUT_1D = { + "in": "tensor_layout::convolution::NWGC", + "wei": "tensor_layout::convolution::GKXC", + "out": "tensor_layout::convolution::NWGK", + } + + LAYOUT_2D = { + "in": "tensor_layout::convolution::NHWGC", + "wei": "tensor_layout::convolution::GKYXC", + "out": "tensor_layout::convolution::NHWGK", + } + + LAYOUT_3D = { + "in": "tensor_layout::convolution::NDHWGC", + "wei": "tensor_layout::convolution::GKZYXC", + "out": "tensor_layout::convolution::NDHWGK", + } + + @classmethod + def get_layouts(cls, ndim: int) -> dict: + if ndim == 1: + return cls.LAYOUT_1D + elif ndim == 2: + return cls.LAYOUT_2D + else: + return cls.LAYOUT_3D + + +# ============================================================================ +# CK Tile Grouped Conv Kernel Generator +# ============================================================================ + + +class CKTileGroupedConvKernelGenerator: + """Generates CK Tile grouped convolution kernel instance code""" + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + self.tm = GroupedConvTypeMappings() + + def generate(self, config: GroupedConvKernelConfig) -> str: + """Generate complete CK Tile grouped convolution kernel""" + kernel_name = config.name(self.datatype) + return f"""{self._header(kernel_name, config)} +{self._config_struct(config, kernel_name)} +{self._kernel_instance(config, kernel_name)} +""" + + def _header(self, kernel_name: 
str, config: GroupedConvKernelConfig) -> str: + """Generate header includes based on variant""" + if self.variant == GroupedConvVariant.BACKWARD_DATA: + kernel_header = "grouped_convolution_backward_data_kernel.hpp" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + kernel_header = "grouped_convolution_backward_weight_kernel.hpp" + else: + kernel_header = "grouped_convolution_forward_kernel.hpp" + + elementwise_include = "" + if config.trait.two_stage: + elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"' + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name} +// Variant: {self.variant.value} +#pragma once + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include} + +using namespace ck_tile; +""" + + def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str: + """Generate config struct""" + t = config.tile + tr = config.trait + layouts = self.tm.get_layouts(config.ndim_spatial) + + return f""" +// Kernel configuration +struct {kernel_name}_Config {{ + // Data types + using InDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using WeiDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using OutDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + + // Layouts + using InLayout = {layouts["in"]}; + using WeiLayout = {layouts["wei"]}; + using OutLayout = {layouts["out"]}; + + // Tile shape + static constexpr index_t M_Tile = {t.tile_m}; + static constexpr index_t N_Tile = {t.tile_n}; + static constexpr index_t K_Tile = {t.tile_k}; + + static constexpr index_t M_Warp = {t.warp_m}; + 
static constexpr index_t N_Warp = {t.warp_n}; + static constexpr index_t K_Warp = {t.warp_k}; + + static constexpr index_t M_Warp_Tile = {t.warp_tile_m}; + static constexpr index_t N_Warp_Tile = {t.warp_tile_n}; + static constexpr index_t K_Warp_Tile = {t.warp_tile_k}; + + // Vector sizes + static constexpr index_t VectorSizeA = {config.vector_size_a}; + static constexpr index_t VectorSizeB = {config.vector_size_b}; + static constexpr index_t VectorSizeC = {config.vector_size_c}; + + // Padding + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + + // Pipeline & Epilogue + static constexpr auto Pipeline = {self.tm.PIPELINE_TO_CK[tr.pipeline]}; + static constexpr auto Scheduler = {self.tm.SCHEDULER_TO_CK[tr.scheduler]}; + static constexpr bool DoubleSmemBuffer = {str(tr.double_smem_buffer).lower()}; + static constexpr bool UseCShuffleEpilogue = {str(tr.epilogue == "cshuffle").lower()}; + + // Other params + static constexpr int kBlockPerCu = {config.block_per_cu}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + static constexpr index_t NumGroupsToMerge = {tr.num_groups_to_merge}; + static constexpr bool EnableSplitImage = {str(tr.split_image).lower()}; + static constexpr bool ExplicitGemm = {str(tr.explicit_gemm).lower()}; + static constexpr index_t NDimSpatial = {config.ndim_spatial}; + + // Target architecture + static constexpr const char* TargetArch = "{config.arch}"; +}}; +""" + + def _kernel_instance( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate kernel instantiation code with launch function""" + tr = config.trait + + if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage: + return self._kernel_instance_two_stage(config, kernel_name) + + # Variant-specific configuration + if self.variant == GroupedConvVariant.BACKWARD_DATA: + host_args_type = 
"GroupedConvBwdDataHostArgs" + kernel_type = "GroupedConvolutionBackwardDataKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdData" + layout_suffix = "BwdData" + # For bwd_data: A=dOutput, B=Weight, C=dInput + a_dtype = "OutDataType" + b_dtype = "WeiDataType" + c_dtype = "InDataType" + gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "BWD_DATA" + launcher_alias = "SelectedConvBwdDataLauncher" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + host_args_type = "GroupedConvBwdWeightHostArgs" + kernel_type = "GroupedConvolutionBackwardWeightKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight" + layout_suffix = "BwdWeight" + # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker) + a_dtype = "OutDataType" + b_dtype = "InDataType" + c_dtype = "WeiDataType" + gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end()" + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + else: # Forward + host_args_type = "GroupedConvFwdHostArgs<>" + kernel_type = "GroupedConvolutionForwardKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsFwd" + layout_suffix = "Fwd" + a_dtype = "InDataType" + b_dtype = "WeiDataType" + c_dtype = "OutDataType" + gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "FWD" + launcher_alias = "SelectedConvKernelLauncher" + + # Create valid C++ namespace name + ns_name = "ns_" + kernel_name.replace("-", "_") + + return f""" +// Unique namespace for this kernel to avoid conflicts when including multiple kernels +namespace {ns_name} {{ + +// Bring Config into namespace +using Config = {kernel_name}_Config; + +// Kernel name for identification +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; + +// Selected kernel alias +using 
SelectedConv{direction_prefix.title()}Kernel = Config; + +// ============================================================================= +// Kernel Launch Implementation ({self.variant.value}) +// ============================================================================= + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; // Use the Config alias from namespace + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + + // Implicit GEMM shape + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + // Convolution traits + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + // Tile partitioner + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + // Universal traits - layout suffix changes per variant + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename GroupedConvTraitsType::AsLayout{layout_suffix}, + typename GroupedConvTraitsType::BsLayout{layout_suffix}, + typename GroupedConvTraitsType::CLayout{layout_suffix}, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + 
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + // Pipeline problem - data types change per variant + using GemmPipelineProblem = GemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, + typename GroupedConvTraitsType::template {gemm_traits}, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + // Base pipeline for tail handling + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const {host_args_type}& args, const stream_config& s) {{ + const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + using ConvEpilogue = CShuffleEpilogue, AccDataType, {c_dtype}, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, 
Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>; + + using Kernel = {kernel_type}< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = Kernel::MakeKernelArgs(args); + + if (!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for grouped conv kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + ave_time = launch_kernel(s, make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} +}}; + +// Launcher alias for tile_engine compatibility +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +// Export specific launcher to global namespace +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +// When used with -include compiler flag, export aliases to global namespace +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + # Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy + # as a second template parameter for conv-specific LDS layout. + # (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3) + # CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies. 
+ _CONV_POLICY_PIPELINES = {"basic_v1", "mem", "compv3"} + + def _get_pipeline(self, pipeline: str) -> str: + """Get pipeline class name.""" + pipelines = { + "basic_v1": "GemmPipelineAGmemBGmemCRegV1", + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "compv5": "GemmPipelineAgBgCrCompV5", + "compv6": "GemmPipelineAgBgCrCompV6", + "comp_async": "GemmPipelineAgBgCrCompAsync", + "basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1", + } + return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3") + + def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str: + """Get full template argument list for pipeline instantiation. + + For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy + as a second template argument for conv-specific LDS banking. + """ + base = self._get_pipeline(pipeline) + if pipeline in self._CONV_POLICY_PIPELINES: + return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>" + return f"{base}<{problem_type}>" + + def _get_base_pipeline(self, pipeline: str) -> str: + """Get base pipeline class name (used for tail handling only). + + Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1 + (there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1). + """ + pipelines = { + "basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "compv5": "BaseGemmPipelineAgBgCrCompV5", + "compv6": "BaseGemmPipelineAgBgCrCompV6", + "comp_async": "BaseGemmPipelineAgBgCrCompAsync", + "basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + } + return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3") + + def _kernel_instance_two_stage( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert. 
+ + Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from + example/ck_tile/20_grouped_convolution/. + """ + tr = config.trait + ns_name = "ns_" + kernel_name.replace("-", "_") + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + + return f""" +namespace {ns_name} {{ + +using Config = {kernel_name}_Config; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; +using SelectedConv{direction_prefix.title()}Kernel = Config; + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + using WorkspaceDataType = float; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + // Two-stage forces VectorSizeC = 1 for workspace writes + static constexpr index_t VectorSizeC_TwoStage = 1; + + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, VectorSizeC_TwoStage, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename 
GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + using GemmPipelineProblem = GemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight, + element_wise::PassThrough, element_wise::PassThrough, WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{ + const index_t gemm_k = args.N_ * std::accumulate( + args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(), + 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + // Epilogue writes to fp32 workspace (not fp16 output) + using 
ConvEpilogue = CShuffleEpilogue, AccDataType, WorkspaceDataType, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using Kernel = GroupedConvolutionBackwardWeightKernel< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + // ElementWise kernel: fp32 workspace -> fp16/bf16 output + using XElementwiseOp = element_wise::UnaryConvert; + using EwBlockTile = sequence<2048>; + using EwBlockWarps = sequence<8>; + using EwWarpTile = sequence<64>; + using EwShape = ElementWiseShape; + using EwProblem = ElementWisePipelineProblem< + WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>; + using EwKernel = ElementWiseKernel; + + // Workspace: G * K * C * product(filter_spatial) elements in fp32 + const index_t spatial_accum = std::accumulate( + args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), + 1, std::multiplies()); + DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType)); + + GroupedConvBwdWeightHostArgs ws_args(args); + auto* c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_buf.GetDeviceBuffer(); + + auto kargs = Kernel::MakeKernelArgs(ws_args); + + if(!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + // ElementWise kernel setup + const index_t ew_block_size = EwKernel::BlockSize(); + const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum; + constexpr 
index_t elems_per_block = EwBlockTile::at(number<0>{{}}); + const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block; + + auto ew_shape = make_tuple(args.G_ * args.K_, + args.C_ * spatial_accum); + auto ew_inputs = make_tuple(static_cast(ws_args.wei_ptr)); + + if(!EwKernel::IsSupportedArgument(ew_shape)) {{ + throw std::runtime_error("ElementWise arguments not supported for two-stage convert"); + }} + + auto preprocess = [&]() {{ + if(kargs.k_batch > 1) + hip_check_error(hipMemsetAsync( + ws_args.wei_ptr, 0, + total_elems * sizeof(WorkspaceDataType), + s.stream_id_)); + }}; + + ave_time = launch_kernel_time_mask( + s, preprocess, + make_kernel(Kernel{{}}, grids, blocks, 0, kargs), + make_kernel( + EwKernel{{}}, ew_grid_size, ew_block_size, 0, + ew_shape, + make_tuple(args.C_ * spatial_accum, 1), + make_tuple(args.C_ * spatial_accum, 1), + ew_inputs, + static_cast(c_ptr))); + + return ave_time; + }} +}}; + +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class GroupedConvDispatcherWrapperGenerator: + """Generates dispatcher integration wrapper following GEMM pattern""" + + # Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp) + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "compv5": "Pipeline::CompV5", + "preshufflev1": "Pipeline::PreShuffleV1", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_DISPATCHER = { + 
"default": "Scheduler::Default", + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + } + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + + def _pipeline_to_dispatcher(self, pipeline: str) -> str: + """Convert pipeline string to dispatcher enum value""" + return self.PIPELINE_TO_DISPATCHER.get( + pipeline.lower(), f"Pipeline::{pipeline.capitalize()}" + ) + + def _scheduler_to_dispatcher(self, scheduler: str) -> str: + """Convert scheduler string to dispatcher enum value""" + return self.SCHEDULER_TO_DISPATCHER.get( + scheduler.lower(), f"Scheduler::{scheduler.capitalize()}" + ) + + def generate( + self, + config: GroupedConvKernelConfig, + kernel_path: Path, + output_dir: Path, + ) -> str: + """Generate dispatcher wrapper with factory function for registry""" + kernel_name = config.name(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + # Determine launcher type based on variant + if self.variant == GroupedConvVariant.FORWARD: + launcher_alias = "SelectedConvKernelLauncher" + host_args_type = "GroupedConvFwdHostArgs<>" + conv_type_str = "forward" + elif self.variant == GroupedConvVariant.BACKWARD_DATA: + launcher_alias = "SelectedConvBwdDataLauncher" + host_args_type = "GroupedConvBwdDataHostArgs" + conv_type_str = "bwd_data" + else: # BACKWARD_WEIGHT + launcher_alias = "SelectedConvBwdWeightLauncher" + host_args_type = "GroupedConvBwdWeightHostArgs" + conv_type_str = "bwd_weight" + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper for: {kernel_name} +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "../{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr; +using ::ck_tile::dispatcher::GroupedConvKernelKey; +using 
::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority; + +// Factory function to create kernel instance for registry +inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + GroupedConvKernelKey key; + key.signature.dtype_in = DataType::FP16; + key.signature.dtype_wei = DataType::FP16; + key.signature.dtype_out = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout = "nhwgc"; + key.signature.conv_type = "{conv_type_str}"; + key.signature.num_dims = {config.ndim_spatial}; + key.signature.groups = 1; + + key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 1}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)}; + key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)}; + key.algorithm.epilogue = Epilogue::CShuffle; + key.gfx_arch = gfx_arch; + + // Create kernel instance that wraps the launcher + return std::make_shared( + key, + "{kernel_name}", + []({host_args_type}& args, const stream_config& cfg) -> float {{ + return {kernel_name}_Launcher::launch(args, cfg); + }} + ); +}} + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile + +// Export launcher alias to global namespace for direct use +using {launcher_alias} = {kernel_name}_Launcher; +""" + + +# ============================================================================ +# Configuration Parser +# ============================================================================ + + +def get_default_configs( + arch: 
def get_default_configs(
    arch: str = "gfx942",
    variants: Optional[List[GroupedConvVariant]] = None,
    ndims: Optional[List[int]] = None,
) -> List[GroupedConvKernelConfig]:
    """Build the predefined grouped-convolution kernel configurations.

    Every requested variant is expanded over the requested spatial
    dimensions, the pipelines allowed for that variant, its tile table and
    its two-stage flags. A combination is kept only when both
    ``trait.is_valid()`` and ``config.is_valid_for_arch()`` accept it.

    Args:
        arch: Target GPU architecture stored on each configuration.
        variants: Variants to expand; ``None`` means forward only.
        ndims: Spatial dimension counts to expand; ``None`` means ``[2]``.

    Returns:
        The accepted configurations, in expansion order.
    """
    selected_variants = [GroupedConvVariant.FORWARD] if variants is None else variants
    selected_ndims = [2] if ndims is None else ndims

    # Tile rows are
    # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k).
    # Forward and backward-data share the standard GEMM-like table
    # (based on CK Tile example configs).
    standard_tiles = [
        (128, 128, 32, 2, 2, 32, 32, 16),  # Standard 128x128
        (256, 256, 32, 2, 2, 32, 32, 16),  # Large 256x256
        (64, 64, 32, 1, 4, 16, 16, 16),  # Small 64x64
        (128, 64, 32, 2, 2, 32, 32, 16),  # Rectangular
        (16, 64, 64, 1, 4, 16, 16, 32),  # Tall and narrow
    ]
    # Backward weight is restricted to the ConvConfigComputeV3 tile from the
    # CK Tile examples (example/ck_tile/20_grouped_convolution/): its
    # transpose_tile2d step only works with specific warp layouts; (1, 4, 1)
    # and (4, 1, 1) are the ones known to work.
    weight_grad_tiles = [
        (16, 64, 64, 1, 4, 16, 16, 32),
    ]

    def _plan_for(kind):
        """Return (tile table, pipeline/epilogue pairs, two-stage flags)."""
        if kind == GroupedConvVariant.BACKWARD_WEIGHT:
            # compv3 only (compv4/compv5 have transpose_tile2d issues); also
            # emit the two-stage form (fp32 workspace + elementwise convert).
            return weight_grad_tiles, [("compv3", "cshuffle")], [False, True]
        if kind == GroupedConvVariant.BACKWARD_DATA:
            # compv3 only (compv4 has get_length issues in the bwd_data kernel).
            return standard_tiles, [("compv3", "cshuffle")], [False]
        # Only forward supports both compv3 and compv4.
        return (
            standard_tiles,
            [("compv3", "cshuffle"), ("compv4", "cshuffle")],
            [False],
        )

    accepted = []
    for variant in selected_variants:
        tile_table, pipeline_pairs, stage_flags = _plan_for(variant)
        for ndim in selected_ndims:
            for pipeline, epilogue in pipeline_pairs:
                uses_compv4 = pipeline == "compv4"
                for tm, tn, tk, wm, wn, wtm, wtn, wtk in tile_table:
                    for two_stage in stage_flags:
                        trait = GroupedConvTraitConfig(
                            pipeline=pipeline,
                            scheduler="intrawave",
                            epilogue=epilogue,
                            double_smem_buffer=uses_compv4,
                            pad_m=True,
                            pad_n=True,
                            pad_k=True,
                            two_stage=two_stage,
                        )
                        if not trait.is_valid():
                            continue

                        candidate = GroupedConvKernelConfig(
                            tile=TileConfig(
                                tile_m=tm,
                                tile_n=tn,
                                # compv4 runs with a doubled K tile
                                tile_k=tk * 2 if uses_compv4 else tk,
                                warp_m=wm,
                                warp_n=wn,
                                warp_k=1,
                                warp_tile_m=wtm,
                                warp_tile_n=wtn,
                                warp_tile_k=wtk,
                            ),
                            trait=trait,
                            variant=variant,
                            ndim_spatial=ndim,
                            arch=arch,
                        )
                        if candidate.is_valid_for_arch():
                            accepted.append(candidate)

    return accepted
def is_config_valid(
    self, config: GroupedConvKernelConfig, datatype: str = "fp16"
) -> bool:
    """Check one configuration against the architecture filter.

    Args:
        config: Kernel configuration to validate.
        datatype: Datatype passed to the filter for the A, B and C operands.

    Returns:
        ``True`` when no architecture filter is active; otherwise whatever
        ``arch_filter.is_kernel_valid`` reports.
    """
    if self.arch_filter and HAS_ARCH_FILTER:
        operator = self._get_operator_type(config.variant)
        tile, trait = config.tile, config.trait
        query = dict(
            datatype_a=datatype,
            datatype_b=datatype,
            datatype_c=datatype,
            tile_m=tile.tile_m,
            tile_n=tile.tile_n,
            tile_k=tile.tile_k,
            warp_m=tile.warp_m,
            warp_n=tile.warp_n,
            warp_k=1,  # Grouped conv typically uses warp_k=1
            warp_tile_m=tile.warp_tile_m,
            warp_tile_n=tile.warp_tile_n,
            warp_tile_k=tile.warp_tile_k,
            pipeline=trait.pipeline,
            epilogue=trait.epilogue,
            scheduler=trait.scheduler,
            operator=operator,
        )
        return self.arch_filter.is_kernel_valid(**query)

    # Filtering disabled or unavailable: accept everything.
    return True
def generate_all(
    self,
    configs: Optional[List[GroupedConvKernelConfig]] = None,
    datatypes: Optional[List[str]] = None,
    parallel: bool = True,
) -> dict:
    """Generate all kernel files (optionally in parallel).

    Configs are filtered with ``is_config_valid`` before generation. Each
    surviving (config, datatype) pair is generated through
    ``parallel_generate``; an exception in one kernel is recorded and does
    not stop the others. When at least one wrapper was written, the
    include-all and registration headers are regenerated.

    Args:
        configs: Configurations to generate; ``None`` uses ``_get_configs()``.
        datatypes: Datatypes to generate for; ``None`` uses ``[self.datatype]``.
        parallel: Allow parallel generation (only used for more than one item).

    Returns:
        Dict with keys ``kernels`` (kernel header paths), ``wrappers``
        (wrapper header paths) and ``failed`` (one ``"<kernel>: <error>"``
        string per kernel that raised).
    """
    if configs is None:
        configs = self._get_configs()
    if datatypes is None:
        datatypes = [self.datatype]

    results = {"kernels": [], "wrappers": [], "failed": []}

    # Filter configs using arch validation
    valid_tasks = []
    rejected_count = 0

    for datatype in datatypes:
        for config in configs:
            if self.is_config_valid(config, datatype):
                valid_tasks.append((config, datatype, config.variant))
            else:
                rejected_count += 1
                log.debug(
                    f"Rejected config for {self.gpu_target}: "
                    f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} "
                    f"variant={config.variant.value}"
                )

    if rejected_count > 0:
        log.info(
            f"Filtered {rejected_count} configs for {self.gpu_target}, "
            f"{len(valid_tasks)} remaining"
        )

    total = len(valid_tasks)
    # Number items from 1 so progress output reads "1/N" .. "N/N" rather
    # than "0/N" .. "N-1/N".
    items = [
        _GenItem(position, total, config, datatype, variant)
        for position, (config, datatype, variant) in enumerate(valid_tasks, start=1)
    ]

    def _safe_generate(item):
        """Generate one kernel, turning any exception into a 'fail' record."""
        # Fallback label in case computing the kernel name is what fails.
        label = f"kernel {item.idx}/{item.total}"
        try:
            label = item.config.name(item.datatype)
            k, w = self._generate_single_kernel(item)
            return ("ok", k, w, None)
        except Exception as e:
            # Include the kernel name so the failure is attributable.
            return ("fail", None, None, f"{label}: {e}")

    raw = parallel_generate(
        _safe_generate, items, parallel=parallel and len(items) > 1
    )
    for r in raw:
        if r[0] == "ok":
            results["kernels"].append(r[1])
            results["wrappers"].append(r[2])
        else:
            results["failed"].append(r[3])
            log.error("Failed: %s", r[3])

    # Generate include_all_*.hpp headers for Python ctypes libraries
    if results["wrappers"]:
        self._generate_include_all_headers()

    return results
def _generate_registration_header(
    self,
    fwd_kernels: List[str],
    bwd_data_kernels: List[str],
    bwd_weight_kernels: List[str],
):
    """Write ``register_all_grouped_conv_kernels.hpp`` into the wrapper directory.

    The header includes every ``dispatcher_wrapper_grouped_conv_*.hpp``
    found in the wrapper directory and defines one registration function
    per direction, a combined one, and kernel-count accessors.

    Args:
        fwd_kernels: Forward kernel names.
        bwd_data_kernels: Backward-data kernel names.
        bwd_weight_kernels: Backward-weight kernel names.
    """
    # Include every wrapper present on disk, in sorted order.
    wrapper_names = sorted(
        found.name
        for found in self.wrapper_dir.glob("dispatcher_wrapper_grouped_conv_*.hpp")
    )
    wrapper_includes = "\n".join(f'#include "{name}"' for name in wrapper_names)

    def _registration_block(kernel_names, placeholder):
        """Join the register_kernel calls, or fall back to a comment."""
        calls = "\n    ".join(
            f"registry.register_kernel(generated::make_{kernel}(gfx_arch), priority);"
            for kernel in sorted(kernel_names)
        )
        return calls if calls else placeholder

    fwd_block = _registration_block(fwd_kernels, "// No forward kernels")
    bwd_data_block = _registration_block(
        bwd_data_kernels, "// No backward data kernels"
    )
    bwd_weight_block = _registration_block(
        bwd_weight_kernels, "// No backward weight kernels"
    )

    fwd_count = len(fwd_kernels)
    bwd_data_count = len(bwd_data_kernels)
    bwd_weight_count = len(bwd_weight_kernels)

    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Auto-generated master registration header for grouped conv kernels
#pragma once

#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/grouped_conv_utils.hpp"

{wrapper_includes}

namespace ck_tile {{
namespace dispatcher {{

using Priority = GroupedConvRegistry::Priority;

inline void register_all_grouped_conv_fwd_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {fwd_block}
}}

inline void register_all_grouped_conv_bwd_data_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {bwd_data_block}
}}

inline void register_all_grouped_conv_bwd_weight_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    auto& registry = GroupedConvRegistry::instance();
    {bwd_weight_block}
}}

inline void register_all_grouped_conv_kernels(
    const std::string& gfx_arch = "gfx942",
    Priority priority = Priority::Normal)
{{
    register_all_grouped_conv_fwd_kernels(gfx_arch, priority);
    register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority);
    register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority);
}}

inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {fwd_count}; }}
inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {bwd_data_count}; }}
inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {bwd_weight_count}; }}
inline std::size_t get_grouped_conv_kernel_count() {{ return {fwd_count + bwd_data_count + bwd_weight_count}; }}

}} // namespace dispatcher
}} // namespace ck_tile
"""
    reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp"
    reg_path.write_text(content)
    log.info(f"Generated registration header: {reg_path}")
def _parse_bool_flag(value):
    """argparse ``type=`` callable turning a true/false token into a bool.

    ``type=bool`` is unsuitable for CLI flags because ``bool("false")`` is
    ``True``; this parser accepts the usual spellings and rejects the rest.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}"
    )


def main():
    """Command-line entry point for the grouped convolution code generator.

    Parses the CLI, then either builds one custom configuration (when
    ``--tile-m``, ``--tile-n`` or ``--pipeline`` is given) or takes the
    predefined configurations for the requested variants and dimensions.
    With ``--list-configs`` the configurations are printed and nothing is
    generated; otherwise the kernels are generated and a summary is printed.
    """
    parser = argparse.ArgumentParser(
        description="Unified Grouped Convolution Code Generator"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path("build/generated_kernels"),
        help="Output directory",
    )
    parser.add_argument(
        "--datatype",
        "-d",
        type=str,
        nargs="+",
        default=["fp16"],
        choices=["fp16", "bf16", "fp32"],
        help="Data types to generate",
    )
    parser.add_argument(
        "--variant",
        "-v",
        type=str,
        nargs="+",
        default=["forward"],
        choices=["forward", "bwd_data", "bwd_weight"],
        help="Grouped convolution variants",
    )
    parser.add_argument(
        "--ndim",
        "-n",
        type=int,
        nargs="+",
        default=[2],
        choices=[1, 2, 3],
        help="Spatial dimensions",
    )
    parser.add_argument(
        "--arch",
        "-a",
        type=str,
        default="gfx942",
        choices=["gfx90a", "gfx942", "gfx950", "gfx1201"],
        help="Target GPU architecture",
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--list-configs",
        action="store_true",
        help="List configurations without generating",
    )

    # Individual kernel configuration (when not using predefined configs)
    parser.add_argument("--tile-m", type=int, help="Block tile M dimension")
    parser.add_argument("--tile-n", type=int, help="Block tile N dimension")
    parser.add_argument("--tile-k", type=int, help="Block tile K dimension")
    parser.add_argument("--warp-m", type=int, help="Wave distribution M")
    parser.add_argument("--warp-n", type=int, help="Wave distribution N")
    parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K")
    parser.add_argument("--warp-tile-m", type=int, help="Warp tile M")
    parser.add_argument("--warp-tile-n", type=int, help="Warp tile N")
    parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K")
    parser.add_argument(
        "--pipeline",
        type=str,
        choices=["mem", "compv3", "compv4", "compv5"],
        help="Pipeline type",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["intrawave", "interwave"],
        help="Scheduler type",
    )
    parser.add_argument(
        "--epilogue",
        type=str,
        default="cshuffle",
        choices=["cshuffle", "default"],
        help="Epilogue type",
    )
    # Parsed with _parse_bool_flag so "--pad-m false" really disables padding.
    parser.add_argument(
        "--pad-m", type=_parse_bool_flag, default=True, help="Pad M dimension"
    )
    parser.add_argument(
        "--pad-n", type=_parse_bool_flag, default=True, help="Pad N dimension"
    )
    parser.add_argument(
        "--pad-k", type=_parse_bool_flag, default=True, help="Pad K dimension"
    )
    parser.add_argument("--vector-a", type=int, default=4, help="Vector size A")
    parser.add_argument("--vector-b", type=int, default=8, help="Vector size B")
    parser.add_argument("--vector-c", type=int, default=8, help="Vector size C")
    parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU")
    parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups")
    parser.add_argument(
        "--num-groups-to-merge", type=int, default=1, help="Groups to merge"
    )
    parser.add_argument(
        "--double-smem-buffer",
        type=str,
        default=None,
        help="Double SMEM buffer (true/false)",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Map variant strings to enums
    variant_map = {
        "forward": GroupedConvVariant.FORWARD,
        "bwd_data": GroupedConvVariant.BACKWARD_DATA,
        "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT,
    }
    requested_variants = [variant_map[v] for v in args.variant]

    # Check if user specified custom configuration
    custom_config = (
        args.tile_m is not None or args.tile_n is not None or args.pipeline is not None
    )

    if custom_config:
        # A custom configuration describes exactly one kernel, so only the
        # first variant and first ndim are used; say so instead of dropping
        # the remaining values silently.
        if len(requested_variants) > 1 or len(args.ndim) > 1:
            logging.warning(
                "Custom kernel configuration uses only the first --variant (%s) "
                "and the first --ndim (%s); remaining values are ignored",
                args.variant[0],
                args.ndim[0],
            )

        # Build custom config from CLI arguments
        tile = TileConfig(
            tile_m=args.tile_m or 128,
            tile_n=args.tile_n or 128,
            tile_k=args.tile_k or 64,
            warp_m=args.warp_m or 2,
            warp_n=args.warp_n or 2,
            warp_k=args.warp_k or 1,
            warp_tile_m=args.warp_tile_m or 32,
            warp_tile_n=args.warp_tile_n or 32,
            warp_tile_k=args.warp_tile_k or 16,
        )
        pipeline = args.pipeline or "compv4"
        # Determine double_smem_buffer: use CLI arg if given, else default based on pipeline
        if args.double_smem_buffer is not None:
            dsb = args.double_smem_buffer.lower() == "true"
        else:
            dsb = pipeline == "compv4"  # compv4 requires double buffer

        trait = GroupedConvTraitConfig(
            pipeline=pipeline,
            scheduler=args.scheduler or "intrawave",
            epilogue=args.epilogue or "cshuffle",
            pad_m=args.pad_m,
            pad_n=args.pad_n,
            pad_k=args.pad_k,
            double_smem_buffer=dsb,
            num_groups_to_merge=args.num_groups_to_merge,
        )
        config = GroupedConvKernelConfig(
            tile=tile,
            trait=trait,
            variant=requested_variants[0]
            if requested_variants
            else GroupedConvVariant.FORWARD,
            ndim_spatial=args.ndim[0] if args.ndim else 2,
            arch=args.arch,
            vector_size_a=args.vector_a,
            vector_size_b=args.vector_b,
            vector_size_c=args.vector_c,
            block_per_cu=args.block_per_cu,
            num_wave_groups=args.num_wave_groups,
        )
        filtered_configs = [config]
    else:
        # Get predefined configurations for target arch with requested variants and ndims
        filtered_configs = get_default_configs(
            arch=args.arch, variants=requested_variants, ndims=args.ndim
        )

    if args.list_configs:
        print(f"Grouped convolution configurations for {args.arch}:")
        print(f"  Datatypes: {args.datatype}")
        print(f"  Variants: {args.variant}")
        print(f"  Spatial dims: {args.ndim}")
        print(f"\nConfigurations ({len(filtered_configs)}):")
        for cfg in filtered_configs:
            # Name the kernel for every requested datatype, not just fp16.
            kernel_names = ", ".join(cfg.name(dt) for dt in args.datatype)
            print(f"  - {kernel_names}")
            print(f"    Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
            print(f"    Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
            print(
                f"    WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
            )
            print(
                f"    Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
            )
            print(
                f"    Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
            )
        return

    # Generate
    codegen = UnifiedGroupedConvCodegen(
        output_dir=args.output,
        gpu_target=args.arch,
        enable_arch_filter=True,
    )
    results = codegen.generate_all(
        configs=filtered_configs, datatypes=args.datatype, parallel=True
    )

    print(
        f"\nGenerated {len(results['kernels'])} grouped convolution kernel files "
        f"for {args.arch} in {args.output}"
    )
    if results["failed"]:
        print(f"  Failed: {len(results['failed'])}")
        # Only the first few errors are shown to keep the summary short.
        for err in results["failed"][:5]:
            print(f"    - {err}")


if __name__ == "__main__":
    main()
${EXAMPLE_KERNEL_DIR} ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers ) @@ -322,7 +322,6 @@ function(add_declarative_gpu_example NAME SOURCE) # Force-include the generated registration header target_compile_options(${NAME} PRIVATE -include ${EXAMPLE_HEADER} - -DGEMM_KERNEL_AVAILABLE=1 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal @@ -345,6 +344,7 @@ add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_v add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) +add_declarative_gpu_example(gemm_07_gfx950_minimal gemm/cpp/07_gfx950_minimal.cpp) # ML Heuristic example -- requires LightGBM shared library # Derive site-packages from active Python interpreter (respects virtualenvs) @@ -443,19 +443,79 @@ if(hip_FOUND) endif() add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) +# ============================================================================= +# Grouped Convolution C++ Examples +# ============================================================================= + +add_declarative_gpu_example(grouped_conv_01_basic grouped_conv/cpp/01_basic_grouped_conv.cpp) +add_declarative_gpu_example(grouped_conv_02_all_dirs grouped_conv/cpp/02_all_directions.cpp) +add_declarative_gpu_example(grouped_conv_03_bench_val grouped_conv/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(grouped_conv_04_registry_json grouped_conv/cpp/04_registry_json.cpp) +add_declarative_gpu_example(grouped_conv_05_bwd_data grouped_conv/cpp/05_bwd_data.cpp) +add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bwd_weight.cpp) +add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp) + +# ============================================================================= +# Grouped 
Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D) +# ============================================================================= + +# Kernel output directory for the Python conv library +set(CONV_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/conv_python_fallback") +set(CONV_DISPATCH_HEADER "${CONV_FALLBACK_KERNEL_DIR}/conv_python_dispatch.hpp") + +# Generate ALL conv kernels (fwd/bwd_data/bwd_weight x 2D/3D x multiple tile configs) +# then create the dispatch header with 2D/3D aliases +add_custom_command( + OUTPUT ${CONV_DISPATCH_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_grouped_conv_codegen.py + --variant forward bwd_data bwd_weight --ndim 2 3 + --datatype fp16 --arch ${GPU_TARGET} + --output ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/generate_conv_dispatch_header.py + --kernel-dir ${CONV_FALLBACK_KERNEL_DIR} + --output ${CONV_DISPATCH_HEADER} + COMMENT "Generating conv kernels (fwd/bwd_data/bwd_weight x 2D/3D) for Python library..." 
+ VERBATIM +) + +add_custom_target(generate_conv_fallback_kernels DEPENDS ${CONV_DISPATCH_HEADER}) + +# Conv dynamic library for Python (all 6 kernel variants) +add_library(dispatcher_conv_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/conv_ctypes_lib.cpp) +target_link_libraries(dispatcher_conv_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_conv_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CONV_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_conv_lib PRIVATE + -include ${CONV_DISPATCH_HEADER} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_conv_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels) + message(STATUS "GEMM examples configured - kernels will be generated during 'make'") +message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'") # Convenience target to build all Python ctypes libraries add_custom_target(python_libs - DEPENDS dispatcher_gemm_lib - COMMENT "Building Python ctypes libraries (GEMM)" + DEPENDS dispatcher_gemm_lib dispatcher_conv_lib + COMMENT "Building Python ctypes libraries (GEMM + Conv)" ) # ============================================================================= # Per-Architecture Kernel Generation Targets # ============================================================================= -set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) +set(SUPPORTED_GPU_ARCHS gfx942 gfx950 gfx90a gfx1100 gfx1030) foreach(ARCH ${SUPPORTED_GPU_ARCHS}) # GEMM kernels for this arch diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md index fdee9c3583..24bea821ba 100644 --- a/dispatcher/examples/README.md +++ b/dispatcher/examples/README.md @@ -1,8 +1,6 @@ # CK Tile Dispatcher Examples 
-Comprehensive examples for GEMM operations with GPU execution. - -> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. +Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution. --- @@ -60,11 +58,11 @@ python3 examples/gemm/python/08_heuristics.py ``` examples/ -├── gemm/ -│ ├── cpp/ # 6 C++ GEMM examples -│ └── python/ # 11 Python GEMM examples -│ -└── README.md +|---- gemm/ +| |---- cpp/ # 6 C++ GEMM examples +| +---- python/ # 11 Python GEMM examples +| ++---- README.md ``` --- @@ -201,10 +199,31 @@ rocminfo | grep "Name:" --- -## Archived Examples +## Grouped Convolution -Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples +Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM. -See the archive for convolution functionality reference. +### Infrastructure + +The grouped convolution code generation, utilities, and build scripts are available: + +| Component | Location | +|-----------|----------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | + +### Building Grouped Conv Kernels + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output build/generated_kernels \ + --datatype fp16 --variant forward --ndim 2 + +# Compile a grouped conv example +python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp +``` + +See the [main README](../README.md#grouped-convolution-support) for more details.
diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp index 5e620209f4..ffd2858be4 100644 --- a/dispatcher/examples/gemm/cpp/02_multi_size.cpp +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -21,9 +21,9 @@ * - pipeline: "compv3" -> 1 option (compv4 requires special handling) * - scheduler: "intrawave" -> 1 option * - * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: - * - tile_m must be divisible by (warp_m × warp_tile_m) - * - tile_n must be divisible by (warp_n × warp_tile_n) + * Raw expansion: 3 x 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m x warp_tile_m) + * - tile_n must be divisible by (warp_n x warp_tile_n) * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) * Result: 4 valid wildcard kernels + 1 explicit = 5 total * @@ -70,13 +70,13 @@ DECL_KERNEL_SET(multi_size_kernels, .add(Signature().dtype("fp16").layout("rcr"), Algorithm() .tile(64, 64, 64) - .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) - .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) - .pipeline("*") // "*" → valid pipelines - .scheduler("*") // "*" → valid schedulers + .wave(ANY_INT, ANY_INT, 1) // ANY_INT -> (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT -> (16,16,32), (32,32,16) + .pipeline("*") // "*" -> valid pipelines + .scheduler("*") // "*" -> valid schedulers .epilogue("cshuffle"), "gfx942")); -// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels +// Raw: 3x2=6, arch filter removes 2 invalid -> 4 valid kernels // ============================================================================= // MAIN @@ -116,8 +116,8 @@ int main(int argc, char* argv[]) .pipeline("*") -> expands to valid pipelines = 1 .scheduler("*") -> expands to valid schedulers = 1 - Expanded: 3 × 2 = 6 configs, but arch filter validates each: - - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + 
Expanded: 3 x 2 = 6 configs, but arch filter validates each: + - wave x warp must divide tile: (4,1,1)x(32,32,16) invalid for 64x64 - Result: 4 valid kernels from wildcard + 1 explicit = 5 total )"; diff --git a/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp new file mode 100644 index 0000000000..7e62ad2e4f --- /dev/null +++ b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 07: Minimal gfx950 (CDNA4 / MI350) GEMM + * + * Demonstrates the dispatcher working with gfx950-specific kernels: + * + * - fp16 GEMM with standard tile configs + * - fp8 GEMM with gfx950-extended warp tiles (16x16x128) + * - 160KB LDS: gfx950 raises the LDS from 64KB to 160KB (2.5x) + * + * Build: cd dispatcher/build && cmake .. -DGPU_TARGETS=gfx950 && make gemm_07_gfx950_minimal + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// gfx950-targeted kernel declarations +// ============================================================================= + +DECL_KERNEL_SET(gfx950_gemm_kernels, + + // fp16 128x128x32 -- bread-and-butter config, works on all CDNA + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 128x128x64 -- deeper K tile using more LDS + // LDS usage: 128*64*2 + 128*64*2 = 32768 bytes (fits 64KB, gfx950 has 160KB) + 
.add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 64x64x32 -- small-tile variant for small problems + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 07: gfx950 Minimal GEMM", + "Demonstrates gfx950 (CDNA4 / MI350) dispatcher"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--M", "1024", "Problem M dimension"); + args.add_option("--N", "1024", "Problem N dimension"); + args.add_option("--K", "1024", "Problem K dimension"); + args.add_option("--arch", "gfx950", "GPU architecture (default: gfx950)"); + + if(!args.parse(argc, argv)) + return 0; + + std::string gfx_arch = args.get("--arch", "gfx950"); + + print_header("Example 07: gfx950 (CDNA4) Minimal GEMM"); + + // ========================================================================= + // Architecture info + // ========================================================================= + std::cout << "\ngfx950 (CDNA4 / MI350) highlights:\n"; + std::cout << " - 160KB LDS (up from 64KB on gfx942)\n"; + std::cout << " - Extended FP8 warp tiles: 16x16x128, 32x32x64\n"; + std::cout << " - Packed FP4 support (pk_fp4)\n"; + std::cout << " - Same warp configs as gfx942: [1,4,1], [2,2,1], [4,1,1]\n\n"; + + // ========================================================================= + // Register kernels + // 
========================================================================= + std::cout << "Registering kernels for " << gfx_arch << "...\n"; + + Registry registry; + registry.set_name("gfx950_gemm"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + if(registry.size() == 0) + { + std::cerr << "ERROR: No kernels registered for " << gfx_arch << "!\n"; + std::cerr << " Did you build with -DGPU_TARGETS=gfx950?\n"; + return 1; + } + + // ========================================================================= + // Create Dispatcher + // ========================================================================= + Dispatcher dispatcher(&registry); + + // ========================================================================= + // Setup Problem + // ========================================================================= + const int M = args.get_int("--M", 1024); + const int N = args.get_int("--N", 1024); + const int K = args.get_int("--K", 1024); + + std::cout << "\nProblem: " << M << " x " << N << " x " << K << "\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer<DataType> a_dev(M * K); + GpuBuffer<DataType> b_dev(K * N); + GpuBuffer<DataType> c_dev(M * N); + + std::vector<DataType> a_host(M * K, DataType(1.0f)); + std::vector<DataType> b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Select and Run + // ========================================================================= + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No suitable kernel found for " << M << "x" << N << "x" << K << "\n"; + return 1; + } + std::cout << " Selected: " << 
selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Verify + // ========================================================================= + std::cout << "\nVerification:\n"; + std::vector<DataType> c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast<float>(K); + int errors = 0; + for(int i = 0; i < std::min(M * N, 1024); ++i) + { + if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected value: " << expected << "\n"; + std::cout << " Errors (first 1024 elements): " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + print_separator(); + return passed ? 
0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md index 1d81a90a0e..79d60d1198 100644 --- a/dispatcher/examples/gemm/cpp/README.md +++ b/dispatcher/examples/gemm/cpp/README.md @@ -29,14 +29,14 @@ cd examples ## Examples -| Example | Description | Complexity | -|---------|-------------|------------| -| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | -| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | -| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | -| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | -| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | -| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | +| Example | Description | +|---------|-------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ## Example Details @@ -225,5 +225,5 @@ DECL_KERNEL_SET(my_kernels, ## Related Documentation - [Python GEMM Examples](../python/README.md) -- [Convolution Examples](../../conv/cpp/README.md) +- [C++ Headers (GEMM + Grouped Conv)](../../../include/ck_tile/dispatcher/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py 
b/dispatcher/examples/gemm/python/01_basic_gemm.py index 93a78d24d1..8c23da89e2 100644 --- a/dispatcher/examples/gemm/python/01_basic_gemm.py +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -7,41 +7,37 @@ Example 01: Basic GEMM with Multiple Kernels Demonstrates: -1. Declaring multiple kernel configurations -2. Printing all registered kernels -3. Running each kernel and validating output +1. Building a Registry with multiple kernel configurations +2. Parallel JIT compilation via registry.build() +3. Running each kernel and validating output against NumPy reference 4. Comparing performance across kernels -Complexity: ★★☆☆☆ - Usage: python3 01_basic_gemm.py - python3 01_basic_gemm.py --help python3 01_basic_gemm.py --dtype bf16 python3 01_basic_gemm.py --size 2048 + python3 01_basic_gemm.py --num-kernels 4 + python3 01_basic_gemm.py --workers 4 """ import sys +import time import argparse from pathlib import Path from dataclasses import dataclass -from typing import List sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @dataclass class KernelSpec: - """Specification for a kernel configuration""" - name: str tile_m: int tile_n: int @@ -50,80 +46,37 @@ class KernelSpec: scheduler: str = "intrawave" -# Define multiple kernel configurations to test (50+ kernels) KERNEL_SPECS = [ - # Small tiles - compv3 + # Small tiles KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), - # Small tiles - compv4 KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), - KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), - # Medium tiles - compv3 + # Medium tiles KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), - KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), - # Medium 
tiles - compv4 KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), - KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), - # Rectangular tiles - compv3 + # Rectangular tiles KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), - # Rectangular tiles - compv4 KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), - KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), - KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), - # Large tiles - compv3 + # Large tiles KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), - KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), - KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), - KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), - # Large tiles - compv4 KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), - KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), - KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), - KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), - KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), - # Interwave scheduler variants - KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + # Interwave scheduler KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), - KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), - # More tile_k variations - compv3 - KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), - KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), - 
KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), - # More tile_k variations - compv4 - KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), - KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), - KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), - # Additional rectangular - KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), - KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), - KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), - KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), - # Additional compv4 variants - KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"), - KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), - KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), - KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), ] -def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: - """Create a KernelConfig from a spec""" - # Adjust warp tiles based on tile size - if spec.tile_m <= 64: - warp_m, warp_n = 16, 16 - else: - warp_m, warp_n = 32, 32 - +def spec_to_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + warp_m, warp_n = (16, 16) if spec.tile_m <= 64 else (32, 32) return KernelConfig( dtype_a=dtype, dtype_b=dtype, @@ -148,180 +101,118 @@ def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfi ) -def print_kernel_table(specs: List[KernelSpec], dtype: str): - """Print a formatted table of kernel configurations""" - print("\n" + "=" * 70) - print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") - print("=" * 70) - print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") - print(" " + "-" * 68) - - for i, spec in enumerate(specs, 1): - tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" - print( - f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" - ) - - print(" " + "-" * 68) - print(f" Data type: {dtype}") - - def main(): - parser = argparse.ArgumentParser( - description="Basic GEMM 
Example with Multiple Kernels", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - python3 01_basic_gemm.py # Default FP16 with 4 kernels - python3 01_basic_gemm.py --dtype bf16 # BF16 mode - python3 01_basic_gemm.py --size 2048 # Larger problem size - python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels - """, - ) + parser = argparse.ArgumentParser(description="Basic GEMM with Multiple Kernels") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--size", type=int, default=512, help="Problem size MxNxK") + parser.add_argument("--num-kernels", type=int, default=0, help="0 = all") parser.add_argument( - "--dtype", - default="fp16", - choices=["fp16", "bf16", "fp32"], - help="Data type (default: fp16)", - ) - parser.add_argument( - "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", - ) - parser.add_argument( - "--size", - type=int, - default=512, - help="Problem size MxNxK (default: 512)", - ) - parser.add_argument( - "--num-kernels", - type=int, - default=0, - help="Number of kernels to test (0 = all)", + "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)" ) args = parser.parse_args() - reset_for_example() - print("=" * 70) print("Example 01: Basic GEMM with Multiple Kernels") print("=" * 70) - # Select kernels to test specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS - # ========================================================================= - # Step 1: Print all kernel configurations - # ========================================================================= - print_kernel_table(specs, args.dtype) - - # ========================================================================= - # Step 2: Setup and test each kernel - # ========================================================================= - print("\n" + "=" * 70) - print(" RUNNING 
KERNELS") - print("=" * 70) - - np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 - M, N, K = args.size, args.size, args.size - - results = [] - - print(f"\n Problem size: {M}x{N}x{K}\n") + # Step 1: Build registry print( - f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}" ) - print(" " + "-" * 78) - - for i, spec in enumerate(specs, 1): - # Create unique test data per kernel - np.random.seed(42 + i * 1000) - A = (np.random.randn(M, K) * 0.1).astype(np_dtype) - B = (np.random.randn(K, N) * 0.1).astype(np_dtype) - - # Create config and setup dispatcher - config = create_kernel_config(spec, args.dtype, args.arch) - - setup = setup_gemm_dispatcher( - config=config, - registry_name=f"kernel_{spec.name}", - verbose=False, - auto_rebuild=True, + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 64) + for i, s in enumerate(specs, 1): + print( + f" {i:<3} {s.name:<22} {s.tile_m}x{s.tile_n}x{s.tile_k:<6} {s.pipeline:<10} {s.scheduler:<12}" ) + reg = Registry(name="basic_gemm") + for s in specs: + reg.register_kernel(spec_to_config(s, args.dtype, args.arch)) + + # Step 2: Parallel JIT build via registry.build() + workers = args.workers if args.workers > 0 else None + print( + f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---" + ) + + t0 = time.perf_counter() + setups = reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + built = sum(1 for s in setups if s.success) + print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s") + + if built == 0: + print(" ERROR: No kernels built") + return 1 + + # Step 3: Run each kernel and validate + print(f"\n--- Running Kernels (problem {args.size}x{args.size}x{args.size}) ---") + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else 
np.float32 + M = N = K = args.size + + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + + print( + f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}" + ) + print(" " + "-" * 80) + + results = [] + for i, (spec, setup) in enumerate(zip(specs, setups), 1): tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" if not setup.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - dispatcher = setup.dispatcher - - # Check if size is supported - if not dispatcher.is_supported(M, N, K): + disp = setup.dispatcher + if not disp.is_supported(M, N, K): print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Run GEMM - result = dispatcher.run(A, B, M, N, K) - - if not result.success: + res = disp.run(A, B, M, N, K) + if not res.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Validate against NumPy reference - C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) - max_err = np.max(np.abs(result.output - C_ref)) - - # Check if within tolerance - passed = max_err < 1e-2 - 
status = "PASS" if passed else "FAIL" - + max_err = float(np.max(np.abs(res.output - C_ref))) + ok = max_err < 1e-2 + tag = "PASS" if ok else "FAIL" print( - f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}" ) - results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) - - cleanup_gemm() - - # ========================================================================= - # Step 3: Summary - # ========================================================================= - print("\n" + "=" * 70) - print(" SUMMARY") - print("=" * 70) + results.append((spec.name, ok, res.time_ms, res.tflops, max_err)) + # Step 4: Summary passed = sum(1 for r in results if r[1]) failed = len(results) - passed + valid = [r for r in results if r[1]] - print(f"\n Results: {passed}/{len(results)} kernels passed") - print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") - - if results: - valid_results = [r for r in results if r[1]] - if valid_results: - best = max(valid_results, key=lambda x: x[3]) - print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") - - if failed == 0: - print("\n *** ALL KERNELS PASSED ***") - else: - print(f"\n *** {failed} KERNELS FAILED ***") - + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + print(f" JIT time: {jit_build_s:.1f} s (parallel)") + if valid: + best = max(valid, key=lambda x: x[3]) + print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)") + print(f" Status: {'PASS' if failed == 0 else 'FAIL'}") print("=" * 70) return 0 if failed == 0 else 1 diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py index 039aba2790..745ec1c494 100644 --- a/dispatcher/examples/gemm/python/02_batch_gemm.py +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py 
@@ -6,9 +6,7 @@ """ Example 02: Batch GEMM -Runs multiple GEMM operations with different sizes. - -Complexity: ★★☆☆☆ +Runs multiple GEMM operations with different sizes using JIT compilation. Usage: python3 02_batch_gemm.py @@ -25,9 +23,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -55,20 +52,20 @@ Examples: help="Maximum problem size (default: 4096)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 02: Batch GEMM") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -80,19 +77,22 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="batch_gemm") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run batch of different sizes # ========================================================================= print("\nStep 2: Run Batch") - # Generate sizes up to max_size all_sizes = [ (256, 256, 256), (512, 512, 512), @@ -135,9 +135,6 @@ Examples: avg_tflops = (total_ops / 1e12) / 
(total_time / 1000) print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") - # Cleanup - cleanup_gemm() - print("\n" + "=" * 60) print("Batch GEMM complete!") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py index bec1b7e2fb..508b3f8b35 100644 --- a/dispatcher/examples/gemm/python/03_benchmark.py +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -6,9 +6,8 @@ """ Example 03: Benchmark -Performance benchmarking with compute-optimized kernel configuration. - -Complexity: ★★★☆☆ +Performance benchmarking with compute-optimized kernel configuration +using JIT compilation. Usage: python3 03_benchmark.py @@ -26,9 +25,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -63,20 +61,20 @@ Examples: "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 03: Benchmark") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher with compute-optimized config + # Step 1: JIT build dispatcher with compute-optimized config # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -90,12 +88,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="benchmark") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + 
if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Benchmark @@ -130,11 +132,9 @@ Examples: A = np.random.randn(M, K).astype(np_dtype) * 0.1 B = np.random.randn(K, N).astype(np_dtype) * 0.1 - # Warmup for _ in range(args.warmup): dispatcher.run(A, B, M, N, K) - # Benchmark times = [] for _ in range(args.iterations): result = dispatcher.run(A, B, M, N, K) @@ -150,9 +150,6 @@ Examples: f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}" ) - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) print("Summary") diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py index 2fe54c53f7..d56621c3c8 100644 --- a/dispatcher/examples/gemm/python/04_validation.py +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -6,9 +6,7 @@ """ Example 04: Validation -Validates GPU GEMM against NumPy reference. - -Complexity: ★★★☆☆ +Validates GPU GEMM against NumPy reference using JIT compilation. 
Usage: python3 04_validation.py @@ -26,9 +24,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, Validator, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -56,20 +53,20 @@ Examples: "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 04: Validation") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -81,12 +78,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="validation") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run validation tests @@ -139,9 +140,6 @@ Examples: print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") failed += 1 - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) total = passed + failed diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py index 493ce46d22..b0af5fa700 100644 --- 
a/dispatcher/examples/gemm/python/05_numpy_integration.py +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -8,7 +8,6 @@ Example 05: NumPy Integration Shows how to create a GPU-accelerated matmul wrapper. -Complexity: ★★☆☆☆ Usage: python3 05_numpy_integration.py @@ -29,6 +28,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -70,7 +70,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py index 9e062e507b..780032ce06 100644 --- a/dispatcher/examples/gemm/python/06_json_export.py +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -8,7 +8,6 @@ Example 06: JSON Export Exports registry configuration to JSON. 
-Complexity: ★★☆☆☆ Usage: python3 06_json_export.py @@ -28,6 +27,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -54,7 +54,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py index 8160030631..620e66eeaf 100644 --- a/dispatcher/examples/gemm/python/07_stress_test.py +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -18,7 +18,6 @@ This tests: - Multiple data types (fp16, bf16) - Different schedulers (intrawave, interwave) -Complexity: ★★★★☆ Usage: python3 07_stress_test.py @@ -43,6 +42,7 @@ from ctypes_utils import ( cleanup_gemm, reset_for_example, Validator, + detect_gpu_arch, ) @@ -413,8 +413,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py index e2763c0513..acbf1b3ae0 100644 --- a/dispatcher/examples/gemm/python/08_heuristics.py +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -19,7 +19,6 @@ Heuristic strategies: - Memory-bound: Optimize memory access for bandwidth-limited cases - Latency-focused: Minimize kernel launch overhead for small problems -Complexity: ★★★★☆ Usage: python3 08_heuristics.py @@ -43,6 +42,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -561,8 +561,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture 
(default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py index 97cbce3497..5d9af239d4 100644 --- a/dispatcher/examples/gemm/python/09_multi_registry.py +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -8,7 +8,6 @@ Example 09: Multiple Registries Demonstrates multiple registries for different optimization targets. -Complexity: ★★★★★ Usage: python3 09_multi_registry.py @@ -30,6 +29,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -50,7 +50,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py index e16e4e271f..b1462478d0 100644 --- a/dispatcher/examples/gemm/python/10_advanced_benchmark.py +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -33,6 +33,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -69,7 +70,11 @@ def parse_args(): # Kernel configuration parser.add_argument("--dtype", default="fp16", help="Data type") parser.add_argument("--pipeline", default="compv4", help="Pipeline type") - parser.add_argument("--arch", default="gfx942", help="GPU architecture") + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="GPU architecture (auto-detected from rocminfo)", + ) return parser.parse_args() diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py index 
06743af406..d19395e553 100644 --- a/dispatcher/examples/gemm/python/11_json_import.py +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -15,7 +15,6 @@ Key Features: - Use arch_filter validation on loaded configs - Export to C++ DECL_KERNEL_SET format -Complexity: ★★★☆☆ Usage: python3 11_json_import.py @@ -45,6 +44,7 @@ from ctypes_utils import ( # noqa: E402 cleanup_gemm, reset_for_example, validate_kernel_config, + detect_gpu_arch, ) # Sample JSON configuration (embedded for demonstration) @@ -141,8 +141,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target GPU architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target GPU architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() @@ -236,13 +236,13 @@ Examples: else: invalid_count += 1 if invalid_count <= 3: # Show first 3 invalid - print(f"\n ✗ Invalid: {config.kernel_name()}") + print(f"\n FAIL Invalid: {config.kernel_name()}") for error in result.errors: print(f" Error: {error}") print("\n Validation Summary:") - print(f" ✓ Valid: {valid_count}") - print(f" ✗ Invalid: {invalid_count}") + print(f" OK Valid: {valid_count}") + print(f" FAIL Invalid: {invalid_count}") print(f" Total: {len(configs)}") # ========================================================================= @@ -275,12 +275,12 @@ Examples: disp_config, registry_name="json_import", verbose=False ) if setup.success: - print(" ✓ Dispatcher setup successful") + print(" OK Dispatcher setup successful") print( f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" ) else: - print(f" ⚠ Dispatcher setup: {setup.error}") + print(f" WARNING Dispatcher setup: {setup.error}") print(" (This is expected if kernels aren't generated)") # ========================================================================= diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md index 0a83f3533f..07757b951b 
100644 --- a/dispatcher/examples/gemm/python/README.md +++ b/dispatcher/examples/gemm/python/README.md @@ -295,5 +295,5 @@ Compilation time scales roughly linearly with kernel count. ## Related Documentation - [C++ GEMM Examples](../cpp/README.md) -- [Python Conv Examples](../../conv/python/README.md) +- [Python Utilities](../../../python/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp new file mode 100644 index 0000000000..b503129c57 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp @@ -0,0 +1,203 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 01: Basic Grouped Convolution +// +// Demonstrates three declaration patterns (mirrors GEMM 01): +// 1. AUTOFILL - tile + pipeline only, wave/warp auto-filled +// 2. AUTOCORRECT - invalid wave(1,1,1) corrected to valid config +// 3. FULL - all parameters explicit (matches validated gfx942 config) +// +// Then runs the forward convolution on GPU and verifies output. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_01_basic + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Three declaration patterns -- codegen auto-fills/auto-corrects as needed +DECL_GROUPED_CONV_KERNEL_SET( + basic_conv_kernels, + // Pattern 1: AUTOFILL - only tile + pipeline, rest auto-filled + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").scheduler("intrawave"), + "gfx950") + // Pattern 2: AUTOCORRECT - wave(1,1,1) invalid, corrected to (1,4,1) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 1, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8), + "gfx950") + // Pattern 3: FULL - all parameters explicit (validated config) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 01: Basic Grouped Convolution", + "Declaration patterns + GPU execution"); + args.add_option("--arch", "gfx950", "GPU architecture"); + 
args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-g", "1", "Groups"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 01: Basic Grouped Convolution"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1); + int G = args.get_int("-g", 1); + int C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int HW = args.get_int("--size", 14); + int Y = 3, X = 3; + + // Step 1: Show declared kernel sets + std::cout << "\nStep 1: Declared Kernel Sets\n"; + GroupedConvKernelSetRegistry::instance().print(); + + // Step 2: Register kernels + std::cout << "\nStep 2: Register Kernels\n"; + GroupedConvRegistry registry; + registry.set_name("basic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // Step 3: Create dispatcher + std::cout << "\nStep 3: Create Dispatcher\n"; + GroupedConvDispatcher dispatcher(&registry); + + // Step 4: Build problem using CK Tile ConvParam + std::cout << "\nStep 4: Problem\n"; + auto problem = create_grouped_conv2d_problem(N, C, K, HW, HW, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + print_grouped_conv_problem(problem); + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(HW), static_cast(HW)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + 
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input_host(in_desc); + ck_tile::HostTensor weight_host(wei_desc); + ck_tile::HostTensor output_host(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight_host); + + ck_tile::DeviceMem input_dev(input_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_host.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input_host.data()); + weight_dev.ToDevice(weight_host.data()); + + // Step 5: Select and run + std::cout << "\nStep 5: Select and Run\n"; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found for problem!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + double tflops = calculate_conv_tflops(problem, time_ms); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Verify + std::cout << "\nStep 6: Verify\n"; + output_dev.FromDevice(output_host.data()); + + size_t total = output_host.get_element_space_size(); + size_t nonzero = 0; + double checksum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_host.data()[i]); + if(v != 0.0f) + ++nonzero; + checksum += v; + } + + bool passed = nonzero > 0; // smoke check only: any non-zero output element counts as a pass + std::cout << " Output elements: " << total << "\n"; + std::cout << " Non-zero: " << nonzero << "/" << total + << (nonzero > 0 ? 
" (kernel produced output)" : " WARNING: all zeros!") << "\n"; + std::cout << " Checksum: " << std::fixed << std::setprecision(2) << checksum << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + utils::print_separator(); + std::cout << "DECLARATION PATTERNS:\n"; + std::cout << " 1. AUTOFILL: tile + pipeline only, wave/warp auto-filled\n"; + std::cout << " 2. AUTOCORRECT: invalid wave(1,1,1) corrected\n"; + std::cout << " 3. FULL: all parameters explicit\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp new file mode 100644 index 0000000000..a2f2b9d560 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp @@ -0,0 +1,216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 02: All Convolution Directions +// +// Forward, backward-data, and backward-weight for 2D convolution, +// each executed on GPU with non-zero verification. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_02_all_dirs + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + conv_fwd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdw_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 02: All Convolution Directions", + "Forward/BwdData/BwdWeight with GPU execution and verification"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 02: All Convolution Directions"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + GroupedConvRegistry registry; + registry.set_name("all_directions"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " 
kernel(s)\n"; + + GroupedConvDispatcher dispatcher(&registry); + + const int N = 1, G = 1, C = 64, K = 128, Hi = 14, Wi = 14, Y = 3, X = 3; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + std::cout << "\n " << std::left << std::setw(12) << "Direction" << std::right << std::setw(10) + << "Time(ms)" << std::setw(10) << "TFLOPS" << std::setw(14) << "NonZero" + << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(56, '-') << "\n"; + + bool all_pass = true; + + auto print_result = + [](const char* label, float time_ms, double tflops, size_t nz, size_t total, bool ok) { + std::cout << " " << std::left << std::setw(12) << label << std::right << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << 
std::setw(14) + << (std::to_string(nz) + "/" + std::to_string(total)) << std::setw(10) + << (ok ? "OK" : "FAIL") << "\n"; + }; + + // Forward: run(X, W, Y) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::Forward); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + output_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; // smoke check only: passes if any output element is non-zero + print_result("forward", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + output.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Data: run(dY, W, dX) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + ck_tile::HostTensor dx_host(in_desc); + ck_tile::DeviceMem dx_dev(dx_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(output_dev.GetDeviceBuffer(), // dY (from forward pass) + weight_dev.GetDeviceBuffer(), // W + dx_dev.GetDeviceBuffer(), // dX (output) + problem, + nullptr); + dx_dev.FromDevice(dx_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dx_host.get_element_space_size(); ++i) + if(static_cast(dx_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_data", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dx_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Weight: run(X, dY, dW) + { + auto problem = create_grouped_conv2d_problem( + N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + ck_tile::HostTensor dw_host(wei_desc); + ck_tile::DeviceMem dw_dev(dw_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), // X + output_dev.GetDeviceBuffer(), // dY + 
dw_dev.GetDeviceBuffer(), // dW (output) + problem, + 
nullptr); + dw_dev.FromDevice(dw_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dw_host.get_element_space_size(); ++i) + if(static_cast(dw_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_weight", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dw_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp new file mode 100644 index 0000000000..12bd87d1a4 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp @@ -0,0 +1,263 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 03: Benchmark and CPU-Reference Validation +// +// Runs a 2D grouped conv forward kernel on the GPU via dispatcher.run() +// and compares against the CK Tile host reference implementation. +// Exposes warmup/repeat/log_level as CLI args (matches example 20 pattern). +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_03_bench_val + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; +using AccDataType = float; + +DECL_GROUPED_CONV_KERNEL_SET( + bench_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 03: Benchmark & Validation", + "GPU execution with CPU reference validation"); + args.add_option("-n", "1", "Batch size N"); + args.add_option("-g", "1", "Groups G"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--warmup", "3", "Warmup iterations"); + args.add_option("--repeat", "10", "Benchmark iterations"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_flag("--no-verify", "Skip CPU validation"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 03: Grouped Conv Benchmark & Validation"); + + int N = args.get_int("-n", 1); + 
int G = args.get_int("-g", 1); + int C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14); + int Wi = Hi; + int Y = 3, X = 3; + int warmup = args.get_int("--warmup", 3); + int repeat = args.get_int("--repeat", 10); + bool verify = !args.has("--no-verify"); + std::string gfx_arch = args.get("--arch", "gfx950"); + + std::cout << "\nProblem: N=" << N << " G=" << G << " C=" << C << " K=" << K << " Hi=" << Hi + << " Wi=" << Wi << " Y=" << Y << " X=" << X << "\n"; + std::cout << "Benchmark: warmup=" << warmup << " repeat=" << repeat << "\n"; + + // Step 1: Setup tensors using CK Tile descriptors + std::cout << "\nStep 1: Setup tensors\n"; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output_gpu(out_desc); + ck_tile::HostTensor output_cpu(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + output_cpu.SetZero(); + + std::cout << " Input: " << input.get_element_space_size() << " elements\n"; + std::cout << " Weight: " << weight.get_element_space_size() << " elements\n"; + std::cout << " Output: " << output_gpu.get_element_space_size() << " elements\n"; + + // Step 2: CPU reference + 
if(verify) + { + std::cout << "\nStep 2: CPU Reference\n"; + + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_fwd<2, InDataType, WeiDataType, OutDataType>( + input, weight, output_cpu, strides_v, dilations_v, left_pads_v, right_pads_v); + + std::cout << " CPU ref[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(output_cpu.get_element_space_size())); ++i) + std::cout << std::fixed << std::setprecision(4) + << static_cast(output_cpu.data()[i]) << " "; + std::cout << "\n"; + } + + // Step 3: GPU execution via dispatcher + std::cout << "\nStep 3: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bench_val"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(&registry); + + auto problem = create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_gpu.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + float elapsed_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + output_dev.FromDevice(output_gpu.data()); + + size_t total = output_gpu.get_element_space_size(); + std::cout << " GPU out[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(total)); ++i) + std::cout << std::fixed << std::setprecision(4) << 
static_cast(output_gpu.data()[i]) + << " "; + std::cout << "\n"; + + size_t nonzero_gpu = 0; + double gpu_sum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_gpu.data()[i]); + if(v != 0.0f) + ++nonzero_gpu; + gpu_sum += v; + } + std::cout << " GPU checksum: " << std::fixed << std::setprecision(6) << gpu_sum << "\n"; + std::cout << " GPU non-zero: " << nonzero_gpu << "/" << total + << (nonzero_gpu > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n"; + + int Ho = static_cast(problem.Ho()); + int Wo = static_cast(problem.Wo()); + double flops = 2.0 * G * N * K * C * Y * X * Ho * Wo; + double tflops = flops / (elapsed_ms * 1e9); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << elapsed_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 4: Validation + bool passed = true; + if(verify) + { + std::cout << "\nStep 4: Validation (GPU vs CPU)\n"; + + constexpr float rtol = 1e-2f; + constexpr float atol = 1e-2f; + + float max_diff = 0.0f; + float max_rel = 0.0f; + size_t max_diff_idx = 0; + size_t num_elements = output_gpu.get_element_space_size(); + size_t mismatches = 0; + + for(size_t i = 0; i < num_elements; ++i) + { + float gpu_val = static_cast(output_gpu.data()[i]); + float cpu_val = static_cast(output_cpu.data()[i]); + float diff = std::abs(gpu_val - cpu_val); + float tol = atol + rtol * std::abs(cpu_val); // combined absolute + relative tolerance, scaled by the CPU value + float rel = diff / (std::abs(cpu_val) + 1e-6f); + if(diff > max_diff) + { + max_diff = diff; + max_diff_idx = i; + } + max_rel = std::max(max_rel, rel); + if(diff > tol) + ++mismatches; + } + + passed = (mismatches == 0); + + std::cout << " Side-by-side at worst element [" << max_diff_idx << "]:\n"; + std::cout << " GPU: " << std::fixed << std::setprecision(6) + << static_cast(output_gpu.data()[max_diff_idx]) + << " CPU: " << static_cast(output_cpu.data()[max_diff_idx]) + << " diff: " << 
std::scientific << max_diff << "\n"; + std::cout << " Elements: " << 
num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "/" << num_elements << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_diff << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + std::cout << " Status: " << (passed ? "PASSED" : "FAILED") << "\n"; + } + + utils::print_separator(); + std::cout << "BENCHMARK & VALIDATION:\n"; + std::cout << " GPU kernel: " << (selected ? selected->name() : "none") << "\n"; + std::cout << " Performance: " << std::fixed << std::setprecision(2) << tflops + << " TFLOPS\n"; + std::cout << " CPU reference: reference_grouped_conv_fwd<2>()\n"; + std::cout << " Validation: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp new file mode 100644 index 0000000000..0e5a6d33be --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 04: Heuristic Selection + JSON Export +// +// Demonstrates runtime kernel selection with heuristic ranking, +// GPU execution, and JSON registry export. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_04_registry_json + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Two tile configs for heuristic selection +DECL_GROUPED_CONV_KERNEL_SET( + heuristic_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +std::vector conv_heuristic(const GroupedConvProblem& problem) +{ + int64_t spatial = problem.Ho() * problem.Wo(); + if(spatial > 400) + return {"128x128", "64x64"}; + return {"64x64", "128x128"}; +} + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 04: Heuristic + JSON", + "Runtime kernel selection and JSON export"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 04: Heuristic Selection + JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + // Step 1: Register + std::cout << "\nStep 1: Register Kernels" << std::endl; + GroupedConvRegistry registry; + registry.set_name("heuristic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)" << std::endl; 
+ + // Step 2: Heuristic dispatcher + std::cout << "\nStep 2: Heuristic Dispatcher" << std::endl; + GroupedConvDispatcher dispatcher(&registry); + dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(conv_heuristic); + + // Step 3: Select kernels (no GPU yet) + std::cout << "\nStep 3: Kernel Selection" << std::endl; + + auto problem = create_grouped_conv2d_problem(1, 64, 128, 14, 14, 3, 3, 1, 1); + + auto* selected = dispatcher.select_kernel(problem); + std::cout << " Selected: " << (selected ? selected->name() : "none") << std::endl; + + // Step 4: GPU execution + std::cout << "\nStep 4: GPU Execution" << std::endl; + + ck_tile::conv::ConvParam cp{ + 2, + static_cast(1), + static_cast(1), + static_cast(128), + static_cast(64), + {static_cast(3), static_cast(3)}, + {static_cast(14), static_cast(14)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + std::cout << " Creating tensors..." << std::endl; + auto in_d = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(cp); + auto wei_d = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(cp); + auto out_d = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(cp); + + ck_tile::HostTensor input(in_d); + ck_tile::HostTensor weight(wei_d); + ck_tile::HostTensor output(out_d); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + std::cout << " Allocating device memory..." 
<< std::endl; + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + std::cout << " Launching kernel..." << std::endl; + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + std::cout << " Reading back..." << std::endl; + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms" + << std::endl; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_conv_tflops(problem, time_ms) + << std::endl; + std::cout << " NonZero: " << nz << "/" << output.get_element_space_size() << std::endl; + + // Step 5: JSON export + std::cout << "\nStep 5: JSON Export" << std::endl; + std::string json = registry.export_json(false); + std::cout << " JSON size: " << json.size() << " bytes" << std::endl; + + bool passed = nz > 0; // smoke check only: any non-zero output element counts as a pass + utils::print_separator(); + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp new file mode 100644 index 0000000000..35595bb14c --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 05: Backward Data with CPU Reference Validation +// +// Computes dX = ConvBwdData(dY, W) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_data. +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_05_bwd_data + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_data_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 05: Backward Data Validation", + "dX = ConvBwdData(dY, W) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 05: Backward Data with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + 
{1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // dY (gradient from next layer) and W (weight) are inputs; dX is output + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor dx_gpu(in_desc); + ck_tile::HostTensor dx_cpu(in_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + dx_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_data)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_data<2, InDataType, WeiDataType, OutDataType>( + dx_cpu, weight, dy, strides_v, dilations_v, left_pads_v, right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution via dispatcher + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bwd_data"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_data kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dx_dev(dx_gpu.get_element_space_size_in_bytes()); + + dy_dev.ToDevice(dy.data()); + wei_dev.ToDevice(weight.data()); + + // dispatcher.run(dY, W, dX, problem) for bwd_data + float time_ms = dispatcher.run(dy_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer(), + problem, + nullptr); + + dx_dev.FromDevice(dx_gpu.data()); + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dx_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dx_gpu.data()[i]); + float cv = static_cast(dx_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dX = ConvBwdData(dY, W)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 
0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp new file mode 100644 index 0000000000..41cb75aecf --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 06: Backward Weight with CPU Reference Validation +// +// Computes dW = ConvBwdWeight(X, dY) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_weight. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_06_bwd_weight + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_weight_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 06: Backward Weight Validation", + "dW = ConvBwdWeight(X, dY) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input 
channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--split-k", "1", "Split-K factor for bwd_weight (k_batch)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 06: Backward Weight with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // X (input) and dY (gradient) are inputs; dW is output + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor dw_gpu(wei_desc); + ck_tile::HostTensor dw_cpu(wei_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + dw_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_weight)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_weight<2, InDataType, WeiDataType, OutDataType>( + input, dw_cpu, dy, strides_v, dilations_v, left_pads_v, 
right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bwd_weight"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + problem.split_k = args.get_int("--split-k", 1); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_weight kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dw_dev(dw_gpu.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + dy_dev.ToDevice(dy.data()); + if(problem.split_k > 1) + dw_dev.SetZero(); + + // dispatcher.run(X, dY, dW, problem) for bwd_weight + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + dw_dev.GetDeviceBuffer(), + problem, + nullptr); + + dw_dev.FromDevice(dw_gpu.data()); + + double tflops = (time_ms > 0) ? 
calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dw_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dw_gpu.data()[i]); + float cv = static_cast(dw_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dW = ConvBwdWeight(X, dY)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp new file mode 100644 index 0000000000..5c95f2c45a --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp @@ -0,0 +1,226 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 07: Multi-Tile Benchmark +// +// Benchmarks multiple tile configurations across ResNet-like problem sizes. +// Exposes warmup, repeat, and init method as CLI args (matching CK Tile +// example 20 patterns). +// +// Build: cd dispatcher/build && cmake .. 
&& make grouped_conv_07_benchmark + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Multiple tile configurations for benchmarking +DECL_GROUPED_CONV_KERNEL_SET( + benchmark_tiles, + // Small tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 4, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Medium tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Large tile - compv4 with double smem buffer + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 256, 256) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 07: Multi-Tile Benchmark", + "Multiple tiles across ResNet-like problem sizes"); + args.add_option("--arch", "gfx950", "GPU architecture"); + 
args.add_option("--warmup", "5", "Warmup iterations (passed to stream_config)"); + args.add_option("--repeat", "20", "Benchmark iterations (passed to stream_config)"); + args.add_option("--init", "0", "Init method: 0=random, 1=linear, 2=constant(1)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 07: Multi-Tile Benchmark"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int warmup = args.get_int("--warmup", 5); + int repeat = args.get_int("--repeat", 20); + int init_method = args.get_int("--init", 0); + + std::cout << "\n Config: warmup=" << warmup << " repeat=" << repeat << " init=" << init_method + << "\n"; + + GroupedConvRegistry registry; + registry.set_name("benchmark"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + // ResNet-like problem sizes + struct BenchProblem + { + const char* label; + int N, C, K, Hi, Wi, Y, X; + }; + + BenchProblem problems[] = { + {"ResNet-stage2", 1, 64, 64, 56, 56, 3, 3}, + {"ResNet-stage3", 1, 128, 128, 28, 28, 3, 3}, + {"ResNet-stage4", 1, 256, 256, 14, 14, 3, 3}, + {"ResNet-stage5", 1, 512, 512, 7, 7, 3, 3}, + {"Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1}, + {"Batch-8", 8, 64, 128, 56, 56, 3, 3}, + }; + + std::cout << "\n " << std::left << std::setw(16) << "Problem" << std::right << std::setw(5) + << "N" << std::setw(5) << "C" << std::setw(5) << "K" << std::setw(5) << "H" + << std::setw(5) << "W" << std::setw(4) << "F" << std::setw(10) << "Time(ms)" + << std::setw(10) << "TFLOPS" << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(74, '-') << "\n"; + + bool all_pass = true; + for(const auto& bp : problems) + { + auto problem = + create_grouped_conv2d_problem(bp.N, bp.C, bp.K, bp.Hi, bp.Wi, bp.Y, bp.X, 1, 1); + problem.op = GroupedConvOp::Forward; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(1), + static_cast(bp.N), + 
static_cast(bp.K), + static_cast(bp.C), + {static_cast(bp.Y), static_cast(bp.X)}, + {static_cast(bp.Hi), static_cast(bp.Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + switch(init_method) + { + case 1: + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(input); + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(weight); + break; + case 2: + ck_tile::FillConstant{1.0f}(input); + ck_tile::FillConstant{1.0f}(weight); + break; + default: + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + break; + } + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + float time_ms = 0; + bool ok = false; + try + { + time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t j = 0; j < output.get_element_space_size(); ++j) + if(static_cast(output.data()[j]) != 0.0f) + ++nz; + ok = nz > 0; + } + catch(const std::exception&) + { + ok = false; + } + + double tflops = (time_ms > 0) ? 
calculate_conv_tflops(problem, time_ms) : 0; + + std::string filter_str = std::to_string(bp.Y) + "x" + std::to_string(bp.X); + std::cout << " " << std::left << std::setw(16) << bp.label << std::right << std::setw(5) + << bp.N << std::setw(5) << bp.C << std::setw(5) << bp.K << std::setw(5) << bp.Hi + << std::setw(5) << bp.Wi << std::setw(4) << filter_str << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << std::setw(10) << (ok ? "OK" : "FAIL") << "\n"; + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Warmup: " << warmup << ", Repeat: " << repeat << ", Init: " << init_method + << "\n"; + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py new file mode 100644 index 0000000000..46f57b3879 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic Grouped Convolution + +Demonstrates: +1. Three kernel configuration patterns (minimal, explicit, full ConvConfigBase) +2. Adding kernels to a registry +3. Validation and auto-correction +4. JIT compilation via registry.build() +5. 
GPU execution with CPU reference verification + +Usage: + python3 01_basic_grouped_conv.py + python3 01_basic_grouped_conv.py --variant bwd_data + python3 01_basic_grouped_conv.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, Cpg = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ( + ho * prob.stride_h + - prob.pad_h + + y * prob.dilation_h + ) + wi = ( + wo * prob.stride_w + - prob.pad_w + + x * prob.dilation_w + ) + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(Cpg): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Basic Grouped Conv Example") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--variant", default="forward", choices=["forward", "bwd_data", "bwd_weight"] + ) + parser.add_argument("--ndim", type=int, default=2, choices=[2, 3]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 01: Basic Grouped Convolution") + print("=" * 70) + + # ========================================================================= + # Step 1: 
Three kernel configuration patterns + # ========================================================================= + print("\n--- Step 1: Kernel Configuration Patterns ---") + + # Pattern 1: MINIMAL -- only variant/dtype/arch, everything else auto-filled + config_minimal = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + ) + print("\n Pattern 1: MINIMAL (defaults auto-filled)") + config_minimal.print_config(indent=" ") + + # Pattern 2: EXPLICIT tile/wave/warp -- user controls tiling strategy + config_explicit = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + ) + print("\n Pattern 2: EXPLICIT tile/wave/warp") + config_explicit.print_config(indent=" ") + + # Pattern 3: FULL ConvConfigBase -- every parameter specified + config_full = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + print("\n Pattern 3: FULL (all ConvConfigBase fields)") + config_full.print_config(indent=" ") + + # ========================================================================= + # Step 2: Build a registry with multiple configs + # ========================================================================= + print("\n--- Step 2: Build Registry ---") + registry = GroupedConvRegistry("basic_conv") + registry.add(config_minimal) + registry.add(config_explicit) + registry.add(config_full) + 
registry.print_registry() + + # ========================================================================= + # Step 3: Validate and auto-correct + # ========================================================================= + print("\n--- Step 3: Validate & Auto-Correct ---") + for i, cfg in enumerate(registry.kernels): + result = validate_grouped_conv_config(cfg.to_dict()) + if result.is_valid: + print(f" Config [{i}] {cfg.tile_str}: VALID") + else: + print(f" Config [{i}] {cfg.tile_str}: needs correction") + corrected, result = auto_correct_grouped_conv_config(cfg.to_dict()) + print(f" After correction: valid={result.is_valid}") + + # ========================================================================= + # Step 4: JIT compile via registry.build() + # ========================================================================= + print("\n--- Step 4: JIT Build (via registry.build()) ---") + + # Use only the first config for the actual GPU run + jit_reg = GroupedConvRegistry("jit") + jit_reg.add(config_minimal) + + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + key = (args.variant, args.ndim) + if key not in runners: + print(" JIT build failed") + return 1 + runner = runners[key] + print(f" JIT build: {jit_build_s:.3f} s") + print(f" Library: {runner.library_path}") + print(f" Kernels: {runner.lib.kernel_names()}") + + # ========================================================================= + # Step 5: Define problem + GPU execution + # ========================================================================= + print("\n--- Step 5: GPU Execution ---") + prob = GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=16, + Wi=16, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + direction=args.variant, + ) + prob.print_problem() + + inp = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np.float16) + wei = 
np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np.float16) + + res = runner.run(inp, wei, prob) + if not res.success: + print(f" GPU execution failed: {res.error}") + runner.cleanup() + return 1 + + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, range=[{res.output.min():.3f}, {res.output.max():.3f}]" + ) + + # ========================================================================= + # Step 6: CPU reference (forward 2D only) + # ========================================================================= + verified = False + if args.variant == "forward" and args.ndim == 2: + print("\n--- Step 6: CPU Reference Verification ---") + ref = cpu_conv2d_fwd(inp, wei, prob) + gpu_f32 = res.output.astype(np.float32) + diff = np.abs(gpu_f32 - ref) + max_abs = diff.max() + max_rel = (diff / (np.abs(ref) + 1e-6)).max() + match = np.allclose(gpu_f32, ref, atol=0.05, rtol=0.05) + print(f" max_abs_diff: {max_abs:.6f}") + print(f" max_rel_diff: {max_rel:.6f}") + print(f" Match: {match}") + verified = match + + runner.cleanup() + + # Summary + print("\n" + "=" * 70) + status = ( + "PASS" if res.success and (verified or args.variant != "forward") else "FAIL" + ) + print(f" Status: {status}") + print( + f" {config_minimal.name} | {prob.gflops:.2f} GFLOPs | {res.tflops:.2f} TFLOPS" + ) + print(f" JIT build time: {jit_build_s:.3f} s") + print(f" Registry: {len(registry)} configs (3 patterns demonstrated)") + print("=" * 70) + return 0 if status == "PASS" else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/02_forward.py b/dispatcher/examples/grouped_conv/python/02_forward.py new file mode 100644 index 0000000000..8f59db05a1 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/02_forward.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Example 02: Forward Convolution (2D + 3D) + +Declares forward kernels with explicit tile/wave/warp/pipeline parameters, +builds a registry, JIT compiles, runs on GPU, and validates against CPU reference. + +Usage: + python3 02_forward.py + python3 02_forward.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, C = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + x + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(C): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Forward Convolution (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 02: Forward Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # ========================================================================= + # Step 1: Declare forward kernels with explicit parameters + # 
========================================================================= + print("\n--- Step 1: Declare Forward Kernels ---") + reg = GroupedConvRegistry("forward_conv") + + # Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build via registry + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + for key in [("forward", 2), ("forward", 3)]: + tag = "OK" if key in runners else "FAILED" + print(f" {key[0]} {key[1]}D: {tag}") + + if ("forward", 2) not in runners: + print(" ERROR: forward 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: Forward 2D -- GPU + CPU reference + # 
========================================================================= + print("\n--- Step 3: Forward 2D ---") + prob_2d = GroupedConvProblem( + N=1, C=64, K=64, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + prob_2d.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob_2d.input_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob_2d.weight_shape()).astype(np_dtype) + + res = runners[("forward", 2)].run(x, w, prob_2d) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, nonzero={np.count_nonzero(res.output)}/{res.output.size}" + ) + + ref = cpu_conv2d_fwd(x, w, prob_2d) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.05) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: Forward 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("forward", 3) in runners: + print("\n--- Step 4: Forward 3D ---") + prob_3d = GroupedConvProblem( + N=1, + C=64, + K=64, + Di=8, + Hi=8, + Wi=8, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="forward", + ) + prob_3d.print_problem() + + x3 = np.random.uniform(-0.5, 0.5, prob_3d.input_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob_3d.weight_shape()).astype(np_dtype) + + res3 = runners[("forward", 3)].run(x3, w3, prob_3d) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms") + print(f" TFLOPS: {res3.tflops:.2f}") + print(f" NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" Forward 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" Forward 3D: {'PASS' if ok_3d 
else 'FAIL'} (non-zero check)") + print(f" JIT build: {jit_s:.1f}s") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/03_bwd_data.py b/dispatcher/examples/grouped_conv/python/03_bwd_data.py new file mode 100644 index 0000000000..a000ba7c96 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/03_bwd_data.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Backward Data Convolution (2D + 3D) + +dX = ConvBwdData(dY, W) + +Declares backward-data kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. + +Usage: + python3 03_bwd_data.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_data(dy, wei, prob): + """CPU ref: compute dX from dY and W.""" + N, Ho, Wo, G, Kpg = dy.shape + _, _, Y, X, C = wei.shape + Hi, Wi = prob.Hi, prob.Wi + dx = np.zeros((N, Hi, Wi, G, C), dtype=np.float32) + for n in range(N): + for g in range(G): + for hi in range(Hi): + for wi in range(Wi): + for c in range(C): + s = 0.0 + for y in range(Y): + for x in range(X): + ho = hi + prob.pad_h - y + wo = wi + prob.pad_w - x + if ho % prob.stride_h == 0 and wo % prob.stride_w == 0: + ho //= prob.stride_h + wo //= prob.stride_w + if 0 <= ho < Ho and 0 <= wo < Wo: + for k in range(Kpg): + s += float(dy[n, ho, wo, g, k]) * float( + wei[g, k, y, x, c] + ) + dx[n, hi, wi, g, c] = s + return dx + + +def main(): + parser = argparse.ArgumentParser(description="Backward Data (2D + 3D)") + 
parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 03: Backward Data Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dX = ConvBwdData(dY, W)") + + # ========================================================================= + # Step 1: Declare bwd_data kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdData Kernels ---") + reg = GroupedConvRegistry("bwd_data_conv") + + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" 
Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_data", 2) not in runners: + print(" ERROR: bwd_data 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdData 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Backward Data 2D ---") + prob = GroupedConvProblem( + N=1, C=32, K=32, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="bwd_data" + ) + prob.print_problem() + + dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np_dtype) + + res = runners[("bwd_data", 2)].run(dy, w, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_data(dy, w, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.1) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdData 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_data", 3) in runners: + print("\n--- Step 4: Backward Data 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_data", + ) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob3.weight_shape()).astype(np_dtype) + res3 = runners[("bwd_data", 3)].run(dy3, w3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for 
r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdData 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdData 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/04_bwd_weight.py b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py new file mode 100644 index 0000000000..48e50cd4a9 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Backward Weight Convolution (2D + 3D) + +dW = ConvBwdWeight(X, dY) + +Declares backward-weight kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. 
+ +Usage: + python3 04_bwd_weight.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_weight(x, dy, prob): + """CPU ref: compute dW from X and dY.""" + N, Hi, Wi, G, C = x.shape + _, Ho, Wo, _, Kpg = dy.shape + Y, X_ = prob.Y, prob.X + dw = np.zeros((G, Kpg, Y, X_, C), dtype=np.float32) + for g in range(G): + for k in range(Kpg): + for y in range(Y): + for xf in range(X_): + for c in range(C): + s = 0.0 + for n in range(N): + for ho in range(Ho): + for wo in range(Wo): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + xf + if 0 <= hi < Hi and 0 <= wi < Wi: + s += float(x[n, hi, wi, g, c]) * float( + dy[n, ho, wo, g, k] + ) + dw[g, k, y, xf, c] = s + return dw + + +def main(): + parser = argparse.ArgumentParser(description="Backward Weight (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + parser.add_argument( + "--split-k", type=int, default=1, help="Split-K factor for bwd_weight (k_batch)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 04: Backward Weight Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dW = ConvBwdWeight(X, dY)") + + # ========================================================================= + # Step 1: Declare bwd_weight kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdWeight Kernels ---") + reg = GroupedConvRegistry("bwd_weight_conv") + + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + 
variant="bwd_weight", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_weight", 2) not in runners: + print(" ERROR: bwd_weight 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdWeight 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Backward Weight 2D ---") + prob = GroupedConvProblem( + N=1, + C=32, + K=32, + Hi=8, + Wi=8, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="bwd_weight", + split_k=args.split_k, + ) + prob.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np_dtype) + dy = np.random.uniform(-0.5, 0.5, 
prob.output_shape()).astype(np_dtype) + + res = runners[("bwd_weight", 2)].run(x, dy, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_weight(x, dy, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.5) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdWeight 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_weight", 3) in runners: + print("\n--- Step 4: Backward Weight 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_weight", + ) + x3 = np.random.uniform(-0.5, 0.5, prob3.input_shape()).astype(np_dtype) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + res3 = runners[("bwd_weight", 3)].run(x3, dy3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdWeight 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdWeight 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/05_benchmark.py b/dispatcher/examples/grouped_conv/python/05_benchmark.py new file mode 100644 index 0000000000..9166ab988e --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/05_benchmark.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 + +# 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: Multi-Problem GPU Benchmark + +Declares kernels with explicit tile/wave/warp/pipeline parameters for +all directions, builds registries, JIT compiles, and benchmarks across +ResNet-like problem sizes with configurable warmup/repeat. + +Usage: + python3 05_benchmark.py + python3 05_benchmark.py --warmup 3 --repeat 10 + python3 05_benchmark.py --workers 4 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def compute_bytes(prob, dtype_bytes=2): + in_elems = 1 + for d in prob.input_shape(): + in_elems *= d + wei_elems = 1 + for d in prob.weight_shape(): + wei_elems *= d + out_elems = 1 + for d in prob.output_shape(): + out_elems *= d + return (in_elems + wei_elems + out_elems) * dtype_bytes + + +def main(): + parser = argparse.ArgumentParser(description="Multi-Problem GPU Benchmark") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") + parser.add_argument("--repeat", type=int, default=5, help="Benchmark iterations") + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 05: Multi-Problem GPU Benchmark") + print("=" * 70) + print(f"\n Arch: {args.arch}, Dtype: {args.dtype}") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + + # ========================================================================= + # Step 1: Declare all kernels with explicit parameters + # 
========================================================================= + print("\n--- Step 1: Declare Kernels ---") + reg = GroupedConvRegistry("benchmark") + + # Forward 2D: compv4, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # 
========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runner_by_key = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + for key in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]: + tag = "OK" if key in runner_by_key else "FAILED" + print(f" {key[0]:12s} {key[1]}D: {tag}") + print(f" JIT build time: {jit_s:.3f} s") + + missing = [ + k + for k in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)] + if k not in runner_by_key + ] + if missing: + print(f"\n ERROR: missing {missing}") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + def bench_run(runner, inp, wei, prob): + for _ in range(args.warmup): + runner.run(inp, wei, prob) + times = [] + for _ in range(args.repeat): + r = runner.run(inp, wei, prob) + if r.success: + times.append(r.time_ms) + if not times: + return 0.0, 0.0 + return min(times), sum(times) / len(times) + + # ========================================================================= + # Step 3: 2D Forward benchmark + # ========================================================================= + print("\n--- Step 3: Forward 2D Benchmark ---") + print( + f"{'Problem':<18} {'N':>3} {'C':>4} {'K':>4} {'H':>3} {'W':>3} " + f"{'F':>3} {'Min(ms)':>9} {'Avg(ms)':>9} {'TFLOPS':>8} {'GB/s':>8}" + ) + print("-" * 85) + + all_ok = True + for label, n, c, k, h, w, y, x, s, p in [ + ("ResNet-stage2", 1, 64, 64, 56, 56, 3, 3, 1, 1), + ("ResNet-stage3", 1, 128, 128, 28, 28, 3, 3, 1, 1), + ("ResNet-stage4", 1, 256, 256, 14, 14, 3, 3, 1, 1), + ("ResNet-stage5", 1, 512, 512, 7, 7, 3, 3, 1, 1), + ("Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1, 1, 0), + ("Batch-8", 8, 64, 128, 56, 56, 3, 3, 1, 1), + ("Batch-32", 32, 64, 
128, 56, 56, 3, 3, 1, 1), + ]: + prob = GroupedConvProblem( + N=n, + C=c, + K=k, + Hi=h, + Wi=w, + Y=y, + X=x, + stride_h=s, + stride_w=s, + pad_h=p, + pad_w=p, + direction="forward", + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + bw = compute_bytes(prob) / (avg_ms * 1e6) + print( + f"{label:<18} {n:>3} {c:>4} {k:>4} {h:>3} {w:>3} " + f"{y}x{x} {min_ms:>9.4f} {avg_ms:>9.4f} {tflops:>8.2f} {bw:>8.1f}" + ) + else: + all_ok = False + + # ========================================================================= + # Step 4: 3D Forward + # ========================================================================= + print("\n--- Step 4: Forward 3D ---") + for label, n, c, k, d, h, w, z, y, x in [ + ("3D-small", 1, 64, 64, 8, 16, 16, 3, 3, 3), + ("3D-medium", 1, 64, 128, 16, 32, 32, 3, 3, 3), + ]: + prob = GroupedConvProblem( + N=n, C=c, K=k, Di=d, Hi=h, Wi=w, Z=z, Y=y, X=x, direction="forward" + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 3)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print(f" {label:<14} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS") + + # ========================================================================= + # Step 5: Backward directions + # ========================================================================= + print("\n--- Step 5: Backward Directions ---") + for label, direction in [ + ("bwd_data ResNet-s3", "bwd_data"), + ("bwd_weight ResNet-s3", "bwd_weight"), + ]: + prob = GroupedConvProblem( + N=1, + C=128, + K=128, + Hi=28, + Wi=28, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + 
direction=direction, + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[(direction, 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print( + f" {label:<14} {direction:>12} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS" + ) + + for runner in runner_by_key.values(): + runner.cleanup() + + print("\n" + "=" * 70) + print(f" JIT build: {jit_s:.3f} s") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + return 0 if all_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/06_registry_json.py b/dispatcher/examples/grouped_conv/python/06_registry_json.py new file mode 100644 index 0000000000..1a3dc854e7 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/06_registry_json.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: Registry, Heuristic Selection & JSON Export + +Declares multiple kernel configurations with different tile sizes, +builds a registry, demonstrates heuristic runtime kernel selection, +JSON round-trip, and GPU execution. 
+ +Usage: + python3 06_registry_json.py + python3 06_registry_json.py --workers 4 +""" + +import sys +import time +import argparse +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def conv_heuristic(problem): + spatial = problem.Ho * problem.Wo + if spatial > 400: + return ["256", "128", "64"] + return ["64", "128", "256"] + + +def main(): + parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 06: Registry, Heuristic Selection & JSON Export") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # Step 1: Declare kernels with full explicit parameters + print("\n--- Step 1: Declare Kernels + Build Registry ---") + reg = GroupedConvRegistry("conv_tiles") + + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=256, + tile_k=256, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + 
block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.print_registry() + + # Step 2: Heuristic kernel selection + print("\n--- Step 2: Heuristic Kernel Selection ---") + problems = [ + ( + "small_7x7", + GroupedConvProblem( + N=1, + C=512, + K=512, + Hi=7, + Wi=7, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "medium_14x14", + GroupedConvProblem( + N=1, + C=256, + K=256, + Hi=14, + Wi=14, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "large_56x56", + GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=56, + Wi=56, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ] + print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}") + print(f" {'-' * 74}") + for label, prob in problems: + selected = reg.select(prob, heuristic=conv_heuristic) + spatial = prob.Ho * prob.Wo + sel_name = selected.name if selected else "none" + print(f" {label:<16} {spatial:>8} {sel_name:<50}") + + # Step 3: JSON round-trip + print("\n--- Step 3: JSON Round-Trip ---") + json_str = reg.to_json() + print(f" Exported: {len(json_str)} bytes, {len(reg)} kernels") + imported = GroupedConvRegistry.from_json(json_str) + print(f" Imported: {len(imported)} kernels") + orig = reg.kernels[0] + imp = imported.kernels[0] + rt_ok = ( + orig.vector_size_a == imp.vector_size_a + and orig.block_per_cu == imp.block_per_cu + and orig.tile_n == imp.tile_n + ) + print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}") + + # Step 4: JIT build + GPU execution + print("\n--- Step 4: 
JIT Build + GPU Execution ---") + workers = args.workers if args.workers > 0 else None + jit_reg = GroupedConvRegistry("jit_conv") + jit_reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + ) + ) + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + if ("forward", 2) not in runners: + print(" JIT build failed") + return 1 + runner = runners[("forward", 2)] + print(f" JIT build: {jit_s:.3f} s") + print(f" Library: {runner.library_path}") + + prob = GroupedConvProblem( + N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + res = runner.run(inp, wei, prob) + runner.cleanup() + + if res.success: + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + gpu_ok = res.success + print("\n" + "=" * 70) + print(f" Registry: {len(reg)} kernels (3 tile configs)") + print(" Heuristic: spatial-based selection demonstrated") + print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}") + print(f" GPU: {'OK' if gpu_ok else 'FAIL'}") + print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}") + print("=" * 70) + return 0 if gpu_ok and rt_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp index 98d8bb9333..b3d8f10675 100644 --- 
a/dispatcher/include/ck_tile/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -3,9 +3,17 @@ #pragma once -/// Main dispatcher header - includes all core components -/// Use this for convenient access to the full dispatcher API +/// Full dispatcher header - includes ALL operation types. +/// For minimal includes, use the per-operation headers instead: +/// ck_tile/dispatcher_gemm.hpp -- GEMM only +/// ck_tile/dispatcher_conv.hpp -- Grouped Convolution only +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM #include "ck_tile/dispatcher/kernel_key.hpp" #include "ck_tile/dispatcher/kernel_config.hpp" #include "ck_tile/dispatcher/kernel_decl.hpp" @@ -13,7 +21,15 @@ #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include "ck_tile/dispatcher/backends/tile_backend.hpp" #include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" #include "ck_tile/dispatcher/utils.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md index db3ce996a9..430798aedd 100644 --- a/dispatcher/include/ck_tile/dispatcher/README.md +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile Dispatcher - C++ Headers -C++ API for the CK Tile dispatcher. +C++ API for the CK Tile dispatcher (GEMM and Grouped Convolution). 
> **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. @@ -8,16 +8,25 @@ C++ API for the CK Tile dispatcher. ``` dispatcher/ -├── dispatcher.hpp # Main dispatcher (kernel selection) -├── registry.hpp # Kernel registry (storage & lookup) -├── problem.hpp # Problem specification -├── kernel_key.hpp # Kernel configuration key -├── kernel_instance.hpp # Kernel instance interface -├── utils.hpp # Utilities (timers, GPU buffers) -│ -└── backends/ # Backend implementations - ├── generated_tile_backend.hpp # CK Tile kernels (production) - └── tile_backend.hpp # Tile backend base +|---- dispatcher.hpp # Main include (includes all below) +| +|---- # GEMM Headers +|---- registry.hpp # Kernel registry (storage & lookup) +|---- problem.hpp # GEMM problem specification +|---- kernel_key.hpp # Kernel configuration key +|---- kernel_instance.hpp # Kernel instance interface +|---- utils.hpp # Utilities (timers, GPU buffers) +| +|---- # Grouped Convolution Headers +|---- grouped_conv_config.hpp # GroupedConvDirection, GroupedConvConfig +|---- grouped_conv_problem.hpp # GroupedConvProblem + ProblemBuilder +|---- grouped_conv_kernel_decl.hpp # GroupedConvKernelDecl, DECL_GROUPED_CONV_KERNEL_SET +|---- grouped_conv_registry.hpp # Thread-safe registry with JSON export & filtering +|---- grouped_conv_utils.hpp # Config creators, validation, benchmark utilities +| ++---- backends/ # Backend implementations + |---- generated_tile_backend.hpp # CK Tile kernels (production) + +---- tile_backend.hpp # Tile backend base ``` ## Quick Start @@ -148,6 +157,69 @@ auto kernel = create_generated_tile_kernel< >(key, name); ``` +## Grouped Convolution API + +### GroupedConvProblem (`grouped_conv_problem.hpp`) + +Problem specification with builder pattern: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" + +using namespace ck_tile::dispatcher; + +auto problem = GroupedConvProblemBuilder() + .n(2).g(1).c(128).k(256) + .input_spatial({28, 28}) 
+ .filter_spatial({3, 3}) + .strides({1, 1}) + .dilations({1, 1}) + .left_pads({1, 1}) + .right_pads({1, 1}) + .build(); + +bool ok = problem.is_valid(); +``` + +### GroupedConvRegistry (`grouped_conv_registry.hpp`) + +Thread-safe registry with JSON export and filtering: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" + +auto& registry = GroupedConvRegistry::instance(); + +// Thread-safe registration +registry.register_kernel(kernel); + +// JSON export +std::string json = registry.export_json(); +registry.export_json_to_file("kernels.json"); + +// Filtering +auto gfx942_kernels = registry.filter_by_arch("gfx942"); +auto matched = registry.filter([](const auto& k) { return k.is_fwd(); }); +``` + +### DECL_GROUPED_CONV_KERNEL_SET (`grouped_conv_kernel_decl.hpp`) + +Declarative kernel definition: + +```cpp +DECL_GROUPED_CONV_KERNEL_SET(my_conv_kernels, + .add( + GroupedConvSignature().dtype("fp16").layout("nhwgc"), + GroupedConvAlgorithm().tile(128, 128, 32).wave(2, 2, 1) + .warp(32, 32, 16).pipeline("compv4"), + "gfx942" + ) +); + +// Register all matching current arch +DECL_GROUPED_CONV_KERNEL_ALL(all_conv_kernels, "gfx942"); +``` + ## Best Practices 1. Use `Release` build for performance @@ -155,6 +227,8 @@ auto kernel = create_generated_tile_kernel< 3. Use `Priority::High` for hand-tuned kernels 4. Reuse dispatcher instances 5. Clear registry between test runs +6. Use `GroupedConvProblemBuilder` for validated problem construction +7. Leverage `export_json()` for kernel inventory and debugging --- diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp new file mode 100644 index 0000000000..04ee1b2d11 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// Generated Convolution Kernel Backend +// +// Wraps CK Tile grouped convolution launchers for use through the +// GroupedConvDispatcher. Each generated kernel launcher is wrapped in +// a ConvKernelRunFn that builds the correct host-args type (forward, +// bwd-data, or bwd-weight) and calls Launcher::launch(). + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +// Buffer context is defined in grouped_conv_registry.hpp (g_conv_dispatch_buffers) +// so there's no circular dependency. + +// Helper: build ck_tile::conv::ConvParam from GroupedConvProblem +inline ck_tile::conv::ConvParam make_conv_param_2d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{ + 2, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[1]), static_cast(p.stride[2])}, + {static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}}; +} + +inline ck_tile::conv::ConvParam make_conv_param_3d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{3, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[0]), + static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[0]), + static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[0]), + static_cast(p.stride[1]), + 
static_cast(p.stride[2])}, + {static_cast(p.dilation[0]), + static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}}; +} + +// Create a RunFn for a forward convolution launcher (2D or 3D) +template +inline GroupedConvKernelInstance::RunFn make_conv_fwd_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvFwdHostArgs<> args( + param, ctx.input_ptr, ctx.weight_ptr, {}, ctx.output_ptr, 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-data convolution launcher. +// Dispatcher convention: run(dY, W, dX, problem) where dX is computed. +// BwdDataHostArgs(param, in_ptr=dX, wei_ptr=W, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_data_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvBwdDataHostArgs args( + param, + ctx.output_ptr, // in_ptr = dX (being computed) + ctx.weight_ptr, // wei_ptr = W + {}, + ctx.input_ptr, // out_ptr = dY (gradient from next layer) + 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? 
ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-weight convolution launcher. +// Dispatcher convention: run(X, dY, dW, problem) where dW is computed. +// BwdWeightHostArgs(param, in_ptr=X, wei_ptr=dW, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_weight_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + const int k_batch = (ctx.split_k > 1) ? ctx.split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, + ctx.input_ptr, // in_ptr = X + ctx.output_ptr, // wei_ptr = dW (being computed) + {}, + ctx.weight_ptr, // out_ptr = dY + k_batch); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/base_registry.hpp b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp new file mode 100644 index 0000000000..2bb940c320 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp @@ -0,0 +1,199 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Shared priority enum used by all registry types +enum class Priority +{ + Low = 0, + Normal = 1, + High = 2 +}; + +/// BaseRegistry: Thread-safe, priority-aware kernel storage shared by GEMM and Conv registries. 
+/// +/// Template Parameters: +/// Derived - CRTP derived class (e.g., Registry, ConvRegistry) +/// KeyType - primary key type (std::string for GEMM, ConvKernelKey for Conv) +/// InstanceType - kernel instance type (KernelInstance, ConvKernelInstance) +/// KeyHash - hash functor for KeyType (defaults to std::hash) +template > +class BaseRegistry +{ + public: + using InstancePtr = std::shared_ptr; + + struct Entry + { + InstancePtr instance; + Priority priority; + }; + + BaseRegistry() = default; + virtual ~BaseRegistry() = default; + + BaseRegistry(BaseRegistry&& other) noexcept + { + std::lock_guard lock(other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + + BaseRegistry& operator=(BaseRegistry&& other) noexcept + { + if(this != &other) + { + std::scoped_lock lock(mutex_, other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + return *this; + } + + BaseRegistry(const BaseRegistry&) = delete; + BaseRegistry& operator=(const BaseRegistry&) = delete; + + /// Register a kernel. If the key already exists, the new entry replaces it + /// unless the existing entry has strictly higher priority. + /// Same-priority registration overwrites (last-writer-wins at equal priority). 
+ bool + register_kernel(const KeyType& key, InstancePtr instance, Priority priority = Priority::Normal) + { + std::lock_guard lock(mutex_); + auto it = entries_.find(key); + if(it != entries_.end() && it->second.priority > priority) + { + return false; + } + entries_[key] = Entry{std::move(instance), priority}; + return true; + } + + [[nodiscard]] std::size_t size() const + { + std::lock_guard lock(mutex_); + return entries_.size(); + } + + [[nodiscard]] bool empty() const + { + std::lock_guard lock(mutex_); + return entries_.empty(); + } + + void clear() + { + std::lock_guard lock(mutex_); + entries_.clear(); + } + + [[nodiscard]] std::string get_name() const + { + std::lock_guard lock(mutex_); + return name_; // return by value to avoid dangling reference + } + + void set_name(const std::string& name) + { + std::lock_guard lock(mutex_); + name_ = name; + } + + [[nodiscard]] std::vector get_all_instances() const + { + std::lock_guard lock(mutex_); + std::vector result; + result.reserve(entries_.size()); + for(const auto& [key, entry] : entries_) + { + result.push_back(entry.instance); + } + return result; + } + + std::size_t merge_from(const BaseRegistry& other, Priority priority = Priority::Normal) + { + std::scoped_lock lock(mutex_, other.mutex_); + std::size_t merged = 0; + for(const auto& [key, entry] : other.entries_) + { + auto it = entries_.find(key); + if(it == entries_.end() || it->second.priority <= priority) + { + entries_[key] = Entry{entry.instance, priority}; + ++merged; + } + } + return merged; + } + + /// Enable automatic JSON export after every kernel registration. + /// Requires the derived class to implement export_json_to_file(path, stats). 
+ void enable_auto_export(const std::string& path, + bool include_statistics = true, + bool export_on_every_registration = true) + { + std::lock_guard lock(mutex_); + auto_export_path_ = path; + auto_export_stats_ = include_statistics; + auto_export_on_register_ = export_on_every_registration; + auto_export_enabled_.store(true, std::memory_order_release); + } + + void disable_auto_export() { auto_export_enabled_.store(false, std::memory_order_release); } + + [[nodiscard]] bool is_auto_export_enabled() const + { + return auto_export_enabled_.load(std::memory_order_acquire); + } + + /// Call after registration to trigger auto-export if enabled. + void perform_auto_export() + { + if(!auto_export_enabled_.load(std::memory_order_acquire)) + return; + std::lock_guard lock(mutex_); + if(auto_export_on_register_) + { + static_cast(this)->export_json_to_file(auto_export_path_, auto_export_stats_); + } + } + + protected: + [[nodiscard]] const std::unordered_map& entries() const + { + return entries_; + } + + [[nodiscard]] std::unordered_map& entries_mut() { return entries_; } + + std::mutex& mutex() const { return mutex_; } + + private: + mutable std::mutex mutex_; + std::unordered_map entries_; + std::string name_ = "default"; + + std::atomic auto_export_enabled_{false}; + bool auto_export_on_register_ = true; + bool auto_export_stats_ = true; + std::string auto_export_path_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp index 6d3f548138..d266d693da 100644 --- a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -23,6 +23,7 @@ #pragma once +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/problem.hpp" #include "ck_tile/dispatcher/registry.hpp" @@ -52,7 +53,11 @@ class Dispatcher /// Constructor /// 
@param registry Registry instance to use (default: global singleton) - explicit Dispatcher(Registry* registry = nullptr); + /// @param gfx_arch Target GPU architecture (e.g. "gfx950") + explicit Dispatcher(Registry* registry = nullptr, const std::string& gfx_arch = ""); + + void set_arch(const std::string& arch) { gfx_arch_ = arch; } + [[nodiscard]] const std::string& arch() const { return gfx_arch_; } /// Register a heuristic function for kernel selection /// @param heuristic Function that maps problems to ranked kernel identifiers @@ -74,7 +79,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -89,7 +94,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run_fused(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -106,7 +111,8 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if kernel not found or doesn't support problem + /// @throws NoKernelFound if the kernel identifier is not registered + /// @throws UnsupportedProblem if the selected kernel does not support the problem [[nodiscard]] float run_explicit(const std::string& kernel_id, const void* a_ptr, const void* b_ptr, @@ -130,10 +136,18 @@ class Dispatcher const Problem& problem, float tolerance = 1e-3f) const; + /// Enable or disable GPU 
benchmarking (timing) on all kernels. + /// When disabled, kernels execute once with no timing overhead + /// (one-shot mode for production plugins). + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + private: Registry* registry_; HeuristicFunction heuristic_; SelectionStrategy strategy_; + std::string gfx_arch_; + bool benchmarking_ = true; /// Select kernel using first-fit strategy [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp new file mode 100644 index 0000000000..98b079f8d9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp @@ -0,0 +1,28 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck_tile { +namespace dispatcher { + +struct DispatcherError : std::runtime_error +{ + using std::runtime_error::runtime_error; +}; + +struct NoKernelFound : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +struct UnsupportedProblem : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp new file mode 100644 index 0000000000..6a39766649 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Log levels for dispatcher transparency: +/// 0 = silent (default) +/// 1 = print selected kernel name +/// 2 = print all candidates considered and acceptance/rejection reasons +inline int get_log_level() +{ + static int level = []() { + const char* env = std::getenv("CK_DISPATCHER_LOG_LEVEL"); + return env ? std::atoi(env) : 0; + }(); + return level; +} + +inline void log_kernel_selected(const std::string& kernel_name, const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] Selected kernel: " << kernel_name << " for " << problem_desc + << std::endl; + } +} + +inline void +log_kernel_candidate(const std::string& kernel_name, bool accepted, const std::string& reason) +{ + if(get_log_level() >= 2) + { + std::cerr << "[CK Dispatcher] Candidate: " << kernel_name << " -> " + << (accepted ? "ACCEPTED" : "REJECTED") + << (reason.empty() ? "" : " (" + reason + ")") << std::endl; + } +} + +inline void log_no_kernel_found(const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] No kernel found for " << problem_desc << std::endl; + } +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp new file mode 100644 index 0000000000..91b7b3ad74 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp @@ -0,0 +1,588 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_config.hpp + * @brief CK Tile Grouped Convolution Configuration with Builder-style naming + * + * This adopts the Signature/Algorithm/Arch pattern from: + * experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp + * + * Structure: + * - Signature: WHAT operation (types, layouts, direction, element ops) + * - Algorithm: HOW it's computed (tiles, warps, pipeline, scheduler, padding) + * - Arch: Target GPU architecture + */ + +#pragma once + +// Use common kernel_key types for DataType, Pipeline, etc. +#include "ck_tile/dispatcher/kernel_key.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// DataType, Pipeline, Scheduler, Epilogue are defined in kernel_key.hpp +// No need to redefine them here + +// ============================================================================= +// Data Type Enum (matching CK Tile numeric types) +// ============================================================================= + +enum class ConvDataType +{ + // Standard floating point + FP32, // float + FP64, // double + FP16, // half_t + BF16, // bf16_t + + // 8-bit float variants (FP8/BF8) + FP8, // fp8_t (E4M3) + BF8, // bf8_t (E5M2) + FP8_E4M3, // Explicit E4M3 format + FP8_E5M2, // Explicit E5M2 format + + // Integer types + INT8, // int8_t + UINT8, // uint8_t + INT32, // int32_t (accumulator) + + // 4-bit types (gfx950+ only) + FP4, // MXFP4 + INT4 // pk_int4_t +}; + +// ============================================================================= +// Direction and Layout Enums +// ============================================================================= + +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + +enum class ConvLayout2D +{ + GNHWC_GKYXC_GNHWK, // NHWC-style + NHWGC_GKYXC_NHWGK, + NGCHW_GKYXC_NGKHW, // NCHW-style + NGCHW_GKCYX_NGKHW +}; + +enum class ConvLayout3D +{ + GNDHWC_GKZYXC_GNDHWK, + 
NDHWGC_GKZYXC_NDHWGK, + NGCDHW_GKZYXC_NGKDHW, + NGCDHW_GKCZYX_NGKDHW +}; + +// ============================================================================= +// Element-wise Operations +// ============================================================================= + +enum class ElementwiseOp +{ + PASS_THROUGH, + BIAS, + BIAS_CLAMP, + SCALE, + BILINEAR, + RELU, + GELU, + SIGMOID, + TANH +}; + +// ============================================================================= +// Grouped Convolution Specialization +// ============================================================================= + +enum class ConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3X3, + FILTER_5X5, + FILTER_7X7 +}; + +// ============================================================================= +// Memory Operation Types (for accumulator operations) +// ============================================================================= + +enum class MemoryOperation +{ + SET, // Direct write (=) + ATOMIC_ADD, // Atomic addition (+=) + ATOMIC_MAX, // Atomic max + ADD // Non-atomic addition +}; + +// ============================================================================= +// Epilogue Types +// ============================================================================= + +enum class EpilogueType +{ + CSHUFFLE, // C-shuffle epilogue + DEFAULT_2D, // Default 2D epilogue + DEFAULT_GEMM_2D, // Default GEMM 2D epilogue + DIRECT_STORE, // Direct store without shuffle + BIAS_ADD, // Add bias + BIAS_ADD_RELU, // Add bias + ReLU + BIAS_ADD_GELU // Add bias + GELU +}; + +// ============================================================================= +// Algorithm Enums (matching builder/types.hpp and CK Tile pipelines) +// ============================================================================= + +enum class PipelineVersion +{ + V1, // Basic pipeline V1 + V2, // Basic pipeline V2 + V3, // Compute V3 (intrawave only) + V4, // Compute V4 (double buffer, 
ping-pong LDS) + V5, // Compute V5 (wave groups) + V6, // Compute V6 (newest) + MEMORY, // Memory pipeline + COMPUTE_ASYNC, // Compute with async copy + PRESHUFFLE_V2 // Preshuffle V2 pipeline +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, + INTERWAVE +}; + +enum class GemmPadding +{ + DEFAULT, + NO_PADDING, // No padding + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING +}; + +// ============================================================================= +// Signature Info (WHAT operation) +// ============================================================================= + +struct GroupedConvSignatureInfo +{ + int spatial_dim = 2; // 1, 2, or 3 + GroupedConvDirection direction = GroupedConvDirection::FORWARD; + std::string in_type = "fp16"; + std::string wei_type = "fp16"; + std::string out_type = "fp16"; + std::string acc_type = "fp32"; + std::string workspace_type = "fp32"; // For two-stage algorithms + std::string bias_type = "fp16"; // For bias epilogue + ElementwiseOp in_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp wei_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp out_element_op = ElementwiseOp::PASS_THROUGH; + ConvSpecialization conv_spec = ConvSpecialization::DEFAULT; + int num_groups = 1; + + // String helpers + static const char* direction_str(GroupedConvDirection dir) + { + switch(dir) + { + case GroupedConvDirection::FORWARD: return "fwd"; + case GroupedConvDirection::BACKWARD_DATA: return "bwd_data"; + case GroupedConvDirection::BACKWARD_WEIGHT: return "bwd_weight"; + default: return "unknown"; + } + } + + static const char* datatype_str(ConvDataType dt) + { + switch(dt) + { + case ConvDataType::FP32: return "fp32"; + case ConvDataType::FP64: return "fp64"; + case ConvDataType::FP16: return "fp16"; + case ConvDataType::BF16: return "bf16"; + case ConvDataType::FP8: return "fp8"; + case ConvDataType::BF8: return "bf8"; + case ConvDataType::FP8_E4M3: return "fp8_e4m3"; 
+ case ConvDataType::FP8_E5M2: return "fp8_e5m2"; + case ConvDataType::INT8: return "int8"; + case ConvDataType::UINT8: return "uint8"; + case ConvDataType::INT32: return "int32"; + case ConvDataType::FP4: return "fp4"; + case ConvDataType::INT4: return "int4"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Algorithm Info (HOW it's computed) +// ============================================================================= + +struct DataTileInfo +{ + int m = 128; // M tile (output spatial * N) + int n = 128; // N tile (K output channels) + int k = 64; // K tile (C input channels) +}; + +struct WarpGemmParams +{ + int gemm_m = 16; // MFMA M dimension (MPerXDL) + int gemm_n = 16; // MFMA N dimension (NPerXDL) + int m_iter = 2; // M iterations per warp (MXdlPerWave) + int n_iter = 2; // N iterations per warp (NXdlPerWave) +}; + +struct BlockWarpConfig +{ + int m_warp = 2; // Warps along M + int n_warp = 2; // Warps along N + int k_warp = 1; // Warps along K + int m_warp_tile = 32; // Warp tile M + int n_warp_tile = 32; // Warp tile N + int k_warp_tile = 16; // Warp tile K +}; + +struct VectorSizeInfo +{ + int a = 4; // Input vector size + int b = 8; // Weight vector size + int c = 8; // Output vector size +}; + +struct GroupedConvAlgorithmInfo +{ + DataTileInfo tile; + BlockWarpConfig warp; + VectorSizeInfo vector_size; + + PipelineVersion pipeline = PipelineVersion::V4; + PipelineScheduler scheduler = PipelineScheduler::INTRAWAVE; + GemmPadding padding = GemmPadding::MNK_PADDING; + MemoryOperation memory_op = MemoryOperation::SET; + EpilogueType epilogue = EpilogueType::CSHUFFLE; + + int thread_block_size = 256; + bool double_smem_buffer = false; + int num_wave_groups = 1; + int block_per_cu = 1; + int num_groups_to_merge = 1; + + // Pipeline string + static const char* pipeline_str(PipelineVersion pv) + { + switch(pv) + { + case PipelineVersion::V1: return "v1"; + case 
PipelineVersion::V2: return "v2"; + case PipelineVersion::V3: return "compv3"; + case PipelineVersion::V4: return "compv4"; + case PipelineVersion::V5: return "compv5"; + case PipelineVersion::V6: return "compv6"; + case PipelineVersion::MEMORY: return "mem"; + case PipelineVersion::COMPUTE_ASYNC: return "comp_async"; + case PipelineVersion::PRESHUFFLE_V2: return "preshuffle_v2"; + default: return "unknown"; + } + } + + static const char* scheduler_str(PipelineScheduler ps) + { + switch(ps) + { + case PipelineScheduler::DEFAULT: return "default"; + case PipelineScheduler::INTRAWAVE: return "intrawave"; + case PipelineScheduler::INTERWAVE: return "interwave"; + default: return "unknown"; + } + } + + static const char* memory_op_str(MemoryOperation mo) + { + switch(mo) + { + case MemoryOperation::SET: return "set"; + case MemoryOperation::ATOMIC_ADD: return "atomic_add"; + case MemoryOperation::ATOMIC_MAX: return "atomic_max"; + case MemoryOperation::ADD: return "add"; + default: return "unknown"; + } + } + + static const char* epilogue_str(EpilogueType et) + { + switch(et) + { + case EpilogueType::CSHUFFLE: return "cshuffle"; + case EpilogueType::DEFAULT_2D: return "default_2d"; + case EpilogueType::DEFAULT_GEMM_2D: return "default_gemm_2d"; + case EpilogueType::DIRECT_STORE: return "direct_store"; + case EpilogueType::BIAS_ADD: return "bias_add"; + case EpilogueType::BIAS_ADD_RELU: return "bias_add_relu"; + case EpilogueType::BIAS_ADD_GELU: return "bias_add_gelu"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Arch Info (Target GPU) +// ============================================================================= + +struct ArchInfo +{ + std::string name = "gfx942"; // MI300X default + int max_waves_per_cu = 8; + int lds_size_kb = 64; + int sgpr_count = 108; + int vgpr_count = 512; + + bool supports_mfma_fp16() const { return name.find("gfx9") != std::string::npos; } + bool 
supports_wmma() const { return name.find("gfx11") != std::string::npos; } +}; + +// ============================================================================= +// Full Grouped Conv Config (combines Signature + Algorithm + Arch) +// ============================================================================= + +struct GroupedConvConfig +{ + GroupedConvSignatureInfo signature; + GroupedConvAlgorithmInfo algorithm; + ArchInfo arch; + + // Generate unique kernel name + std::string name() const + { + std::ostringstream oss; + oss << "grouped_conv_" << GroupedConvSignatureInfo::direction_str(signature.direction) + << "_" << signature.in_type << "_" << signature.spatial_dim << "d" << "_" + << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) << "_" << algorithm.tile.m + << "x" << algorithm.tile.n << "x" << algorithm.tile.k; + return oss.str(); + } + + // Brief description + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + << GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution (" << signature.in_type << ")"; + return oss.str(); + } + + // Detailed description (tree-like) + std::string detailed() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + << GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution Kernel\n"; + + oss << " Signature:\n"; + oss << " Data Type: " << signature.in_type << "\n"; + oss << " Accumulator: " << signature.acc_type << "\n"; + oss << " Groups: " << signature.num_groups << "\n"; + + oss << " Algorithm:\n"; + oss << " Thread Block Size: " << algorithm.thread_block_size << "\n"; + oss << " Data Tile: " << algorithm.tile.m << "x" << algorithm.tile.n << "x" + << algorithm.tile.k << "\n"; + oss << " Warp Config: " << algorithm.warp.m_warp << "x" << algorithm.warp.n_warp << "x" + << algorithm.warp.k_warp << "\n"; + oss << " Warp Tile: " << algorithm.warp.m_warp_tile << "x" << algorithm.warp.n_warp_tile 
+ << "x" << algorithm.warp.k_warp_tile << "\n"; + oss << " Pipeline: " << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) + << "\n"; + oss << " Scheduler: " << GroupedConvAlgorithmInfo::scheduler_str(algorithm.scheduler) + << "\n"; + + oss << " Arch:\n"; + oss << " Target: " << arch.name << "\n"; + + return oss.str(); + } +}; + +// ============================================================================= +// Predefined Configs +// ============================================================================= + +namespace configs { + +// Memory-bound config +template +struct Memory : public GroupedConvConfig +{ + Memory() + { + algorithm.tile = {128, 32, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 1, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::MEMORY; + algorithm.double_smem_buffer = false; + } +}; + +// Compute V3 - Small +template +struct CompV3_Small : public GroupedConvConfig +{ + CompV3_Small() + { + algorithm.tile = {16, 64, 64}; + algorithm.warp = {1, 4, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V3 - Medium +template +struct CompV3_Medium : public GroupedConvConfig +{ + CompV3_Medium() + { + algorithm.tile = {128, 128, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + } +}; + +// Compute V3 - Large +template +struct CompV3_Large : public GroupedConvConfig +{ + CompV3_Large() + { + algorithm.tile = {256, 256, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V4 - Double buffered +template +struct CompV4 : public GroupedConvConfig +{ + CompV4() + { + algorithm.tile = {256, 256, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V4; + algorithm.double_smem_buffer = true; + } +}; + +// Compute V5 - Wave groups +template +struct CompV5 : public 
GroupedConvConfig +{ + CompV5() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {1, 1, 2, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V5; + algorithm.num_wave_groups = 2; + } +}; + +// WMMA config for gfx11xx +template +struct WMMA : public GroupedConvConfig +{ + WMMA() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 2, 1, 16, 16, 16}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + arch.name = "gfx1100"; + } +}; + +// Merged groups config +template +struct CompV3_MergedGroups : public GroupedConvConfig +{ + CompV3_MergedGroups() + { + algorithm.tile = {16, 32, 32}; + algorithm.warp = {1, 2, 1, 16, 16, 32}; + algorithm.vector_size = {4, 8, 8}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.num_groups_to_merge = 2; + } +}; + +} // namespace configs + +// ============================================================================= +// DataType Traits (compile-time type info for CK Tile types) +// ============================================================================= + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; + static constexpr int size_bytes = 4; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; + static constexpr int size_bytes = 8; +}; + +// Forward declare CK Tile types for traits +// Note: actual ck_tile types are defined in ck_tile/core/numeric/ +// These traits allow working with type names at compile time + +// ============================================================================= +// ConvTypeConfig (input/weight/acc/output type combinations) +// ============================================================================= + +template +struct ConvTypeConfig +{ + using input_type = InDataType; + using weight_type = WeiDataType; + using output_type = OutDataType; + using accumulator_type = AccDataType; +}; + +// 
Common type configurations as type aliases +// FP16 -> FP32 accumulator -> FP16 output (most common) +// BF16 -> FP32 accumulator -> BF16 output +// FP8 -> FP32 accumulator -> FP8 output +// INT8 -> INT32 accumulator -> INT8 output + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp new file mode 100644 index 0000000000..8ddfe445ff --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp @@ -0,0 +1,537 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_kernel_decl.hpp + * @brief Declarative grouped convolution kernel specification + * + * USAGE: + * ====== + * + * // Named kernel sets for grouped convolution + * DECL_GROUPED_CONV_KERNEL_SET(gconv_fwd, + * .add("fp16", "nhwc", "forward", 128, 128, 32) + * .add("fp16", "nhwc", "forward", 256, 256, 64) + * ); + * + * // Access at runtime + * auto& set = GroupedConvKernelSetRegistry::instance().get("gconv_fwd"); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace grouped_conv_decl { + +// ============================================================================= +// Wildcard constants +// ============================================================================= + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +// ============================================================================= +// GroupedConvSignature - WHAT operation +// ============================================================================= + +class GroupedConvSignature +{ + public: + std::string dtype_in_ = "fp16"; // Input data type + std::string dtype_wei_ = "fp16"; // Weight data type + std::string dtype_out_ = "fp16"; // Output data type + std::string dtype_acc_ = "fp32"; // 
Accumulator type + std::string dtype_workspace_ = "fp32"; // Workspace type (two-stage algorithms) + std::string dtype_bias_ = "fp16"; // Bias type (bias epilogue) + std::string layout_ = "nhwc"; // Data layout: nhwc, nchw + std::string conv_op_ = "forward"; // forward, bwd_data, bwd_weight + int num_dims_ = 2; // Spatial dimensions: 1, 2, or 3 + int groups_ = 1; // Group grouped convolution + std::string specialization_ = "default"; // Filter specialization + + GroupedConvSignature& dtype(const std::string& in, + const std::string& wei, + const std::string& out, + const std::string& acc = "fp32") + { + dtype_in_ = in; + dtype_wei_ = wei; + dtype_out_ = out; + dtype_acc_ = acc; + return *this; + } + + GroupedConvSignature& dtype(const std::string& all) + { + dtype_in_ = dtype_wei_ = dtype_out_ = dtype_bias_ = all; + dtype_acc_ = dtype_workspace_ = "fp32"; + return *this; + } + + GroupedConvSignature& dtype_workspace(const std::string& ws) + { + dtype_workspace_ = ws; + return *this; + } + + GroupedConvSignature& dtype_bias(const std::string& b) + { + dtype_bias_ = b; + return *this; + } + + GroupedConvSignature& layout(const std::string& l) + { + layout_ = l; + return *this; + } + GroupedConvSignature& conv_type(const std::string& op) + { + conv_op_ = op; + return *this; + } + GroupedConvSignature& dims(int d) + { + num_dims_ = d; + return *this; + } + GroupedConvSignature& groups(int g) + { + groups_ = g; + return *this; + } + GroupedConvSignature& spec(const std::string& s) + { + specialization_ = s; + return *this; + } + + std::string op_str() const + { + if(conv_op_ == "forward") + return "fwd"; + if(conv_op_ == "bwd_data") + return "bwd_data"; + if(conv_op_ == "bwd_weight") + return "bwd_weight"; + return conv_op_; + } +}; + +// ============================================================================= +// GroupedConvAlgorithm - HOW it's implemented +// ============================================================================= + +class 
GroupedConvAlgorithm +{ + public: + // Tile shape (M, N, K per tile - M=spatial*N, N=K_out, K=C_in) + int tile_m_ = 1; // Tile M (output spatial * batch) + int tile_n_ = 128; // Tile N (output channels K) + int tile_k_ = 128; // Tile K (input channels C) + + // Output spatial tile + int tile_ho_ = 1; + int tile_wo_ = 16; + + // Wave/warp shape + int wave_m_ = ANY_INT; + int wave_n_ = ANY_INT; + int wave_k_ = 1; + int warp_m_ = ANY_INT; + int warp_n_ = ANY_INT; + int warp_k_ = 16; + + // Vector sizes + int vector_a_ = 4; // Input vector size + int vector_b_ = 8; // Weight vector size + int vector_c_ = 8; // Output vector size + + // Pipeline configuration + std::string pipeline_ = "compv4"; + std::string scheduler_ = "intrawave"; + std::string epilogue_ = "cshuffle"; + std::string memory_op_ = "set"; // Memory operation: set, atomic_add, atomic_max, add + + // Occupancy/performance hints + int block_size_ = 256; + int block_per_cu_ = 1; + int num_wave_groups_ = 1; + int num_groups_to_merge_ = 1; + bool double_smem_buffer_ = false; + + // Padding -- always enabled for convolution (MNK padding assumed) + static constexpr bool pad_m_ = true; + static constexpr bool pad_n_ = true; + static constexpr bool pad_k_ = true; + + // Tile setter (M, N, K) + GroupedConvAlgorithm& tile(int m, int n, int k) + { + tile_m_ = m; + tile_n_ = n; + tile_k_ = k; + return *this; + } + + GroupedConvAlgorithm& tile_output(int ho, int wo) + { + tile_ho_ = ho; + tile_wo_ = wo; + return *this; + } + + GroupedConvAlgorithm& wave(int m, int n, int k = 1) + { + wave_m_ = m; + wave_n_ = n; + wave_k_ = k; + return *this; + } + + GroupedConvAlgorithm& warp(int m, int n, int k = 16) + { + warp_m_ = m; + warp_n_ = n; + warp_k_ = k; + return *this; + } + + GroupedConvAlgorithm& vector_sizes(int a, int b, int c) + { + vector_a_ = a; + vector_b_ = b; + vector_c_ = c; + return *this; + } + + GroupedConvAlgorithm& pipeline(const std::string& p) + { + pipeline_ = p; + return *this; + } + 
GroupedConvAlgorithm& scheduler(const std::string& s) + { + scheduler_ = s; + return *this; + } + GroupedConvAlgorithm& epilogue(const std::string& e) + { + epilogue_ = e; + return *this; + } + GroupedConvAlgorithm& memory_op(const std::string& m) + { + memory_op_ = m; + return *this; + } + + // Occupancy setters + GroupedConvAlgorithm& block_per_cu(int b) + { + block_per_cu_ = b; + return *this; + } + GroupedConvAlgorithm& num_wave_groups(int n) + { + num_wave_groups_ = n; + return *this; + } + GroupedConvAlgorithm& num_groups_to_merge(int n) + { + num_groups_to_merge_ = n; + return *this; + } + GroupedConvAlgorithm& double_smem_buffer(bool d) + { + double_smem_buffer_ = d; + return *this; + } + + bool needs_expansion() const + { + return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" || scheduler_ == "*"; + } + + /// Check if specific parameter needs expansion + bool needs_wave_expansion() const { return wave_m_ == ANY_INT || wave_n_ == ANY_INT; } + bool needs_warp_expansion() const { return warp_m_ == ANY_INT || warp_n_ == ANY_INT; } + bool needs_pipeline_expansion() const { return pipeline_ == "*"; } + bool needs_scheduler_expansion() const { return scheduler_ == "*"; } + + /// Auto-fill with defaults (for single kernel generation) + void auto_fill() + { + if(wave_m_ == ANY_INT) + wave_m_ = 2; + if(wave_n_ == ANY_INT) + wave_n_ = 2; + if(warp_m_ == ANY_INT) + warp_m_ = 32; + if(warp_n_ == ANY_INT) + warp_n_ = 32; + if(pipeline_ == "*") + pipeline_ = "compv4"; + if(scheduler_ == "*") + scheduler_ = "intrawave"; + } + + /// Get all valid wave configurations for arch + static std::vector> valid_wave_configs(const std::string& arch) + { + // Match arch_specs_generated.py WARP_SUPPORTED_COMBINATIONS + if(arch == "gfx942" || arch == "gfx90a" || arch == "gfx950") + { + return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + } + return {{2, 2, 1}}; // Default + } + + /// Get all valid warp tile configurations + static std::vector> valid_warp_configs(const 
std::string& arch, + const std::string& dtype) + { + // Match arch_specs_generated.py WARP_TILE_SUPPORTED_COMBINATIONS + if(arch == "gfx942" && (dtype == "fp16" || dtype == "bf16")) + { + return {{16, 16, 16}, {32, 32, 16}}; + } + return {{32, 32, 16}}; // Default + } + + /// Get all valid pipeline/scheduler combinations for forward conv. + /// Backward operations (bwd_data/bwd_weight) only support compv3 and mem + /// due to transpose_tile2d and get_length constraints in CK Tile. + static std::vector> valid_trait_configs() + { + return { + {"compv3", "intrawave"}, + {"compv4", "intrawave"}, + {"compv5", "intrawave"}, + {"mem", "intrawave"}, + {"mem", "interwave"}, + }; + } +}; + +// ============================================================================= +// GroupedConvKernelDecl +// ============================================================================= + +struct GroupedConvKernelDecl +{ + GroupedConvSignature signature; + GroupedConvAlgorithm algorithm; + std::string arch = "gfx942"; + + GroupedConvKernelDecl() = default; + + GroupedConvKernelDecl(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& a = "gfx942") + : signature(sig), algorithm(algo), arch(a) + { + } + + std::string name() const + { + std::ostringstream oss; + // Generate full kernel name similar to GEMM: + // grouped_conv____d______ + oss << "grouped_conv_" << signature.op_str() << "_" << signature.dtype_in_ << "_" + << signature.layout_ << "_" << signature.num_dims_ << "d" << "_" << algorithm.pipeline_ + << "_" << algorithm.epilogue_ << "_" << algorithm.scheduler_ << "_" << algorithm.tile_m_ + << "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_ << "_" << algorithm.wave_m_ + << "x" << algorithm.wave_n_ << "x" << algorithm.wave_k_ << "_" << algorithm.warp_m_ + << "x" << algorithm.warp_n_ << "x" << algorithm.warp_k_; + return oss.str(); + } + + bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; } +}; + +// 
============================================================================= +// GroupedConvKernelSet +// ============================================================================= + +class GroupedConvKernelSet +{ + public: + GroupedConvKernelSet() = default; + + GroupedConvKernelSet& add(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + // Simple add: dtype, layout, conv_type, tile_k, tile_c + GroupedConvKernelSet& add(const std::string& dtype, + const std::string& layout, + const std::string& conv_type, + int tile_k, + int tile_c, + const std::string& arch = "gfx942") + { + GroupedConvSignature sig; + sig.dtype(dtype).layout(layout).conv_type(conv_type); + GroupedConvAlgorithm algo; + algo.tile(1, tile_k, tile_c); + decls_.emplace_back(sig, algo, arch); + return *this; + } + + GroupedConvKernelSet& merge(const GroupedConvKernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + size_t size() const { return decls_.size(); } + + void print(std::ostream& os = std::cout) const + { + os << "GroupedConvKernelSet (" << size() << " declarations):\n"; + for(const auto& d : decls_) + { + os << " - " << d.name(); + if(d.algorithm.needs_expansion()) + os << " [expands]"; + os << "\n"; + } + } + + GroupedConvKernelSet& tag(const std::string& t) + { + tag_ = t; + return *this; + } + std::string tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +// ============================================================================= +// GroupedConvKernelSetRegistry +// ============================================================================= + +class GroupedConvKernelSetRegistry +{ + public: + static GroupedConvKernelSetRegistry& instance() + { + static GroupedConvKernelSetRegistry reg; + return 
reg; + } + + void add(const std::string& name, const GroupedConvKernelSet& set) + { + sets_[name] = set; + if(std::find(order_.begin(), order_.end(), name) == order_.end()) + { + order_.push_back(name); + } + } + + // Alias for add() for consistency with GEMM API + void register_set(const std::string& name, const GroupedConvKernelSet& set) { add(name, set); } + + const GroupedConvKernelSet& get(const std::string& name) const + { + static GroupedConvKernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + std::vector names() const { return order_; } + size_t size() const { return sets_.size(); } + + void clear() + { + sets_.clear(); + order_.clear(); + } + + void print() const + { + std::cout << "Grouped Conv Kernel Sets (" << size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + GroupedConvKernelSetRegistry() = default; + std::unordered_map sets_; + std::vector order_; +}; + +// ============================================================================= +// Static Registrar +// ============================================================================= + +struct GroupedConvKernelSetRegistrar +{ + GroupedConvKernelSetRegistrar(const std::string& name, const GroupedConvKernelSet& set) + { + GroupedConvKernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace grouped_conv_decl + +// Convenience aliases +using GroupedConvSignature = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgorithm = grouped_conv_decl::GroupedConvAlgorithm; +using GroupedConvKernelDecl = grouped_conv_decl::GroupedConvKernelDecl; +using GroupedConvKernelSet = grouped_conv_decl::GroupedConvKernelSet; +using GroupedConvKernelSetRegistry = grouped_conv_decl::GroupedConvKernelSetRegistry; + +} // namespace 
dispatcher +} // namespace ck_tile + +// ============================================================================= +// Declaration Macros +// ============================================================================= + +#define CK_GROUPED_CONV_DECL_CAT_(a, b) CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) +#define CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) a##b + +// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension +#define DECL_GROUPED_CONV_KERNEL_SET(name, ...) \ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #name, \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() __VA_ARGS__.tag(#name)) + +#define DECL_GROUPED_CONV_KERNEL_ALL(dtype, layout) \ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #dtype "_" #layout "_all", \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet().add( \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvSignature().dtype(#dtype).layout( \ + #layout), \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvAlgorithm(), \ + "*")) + +#define GROUPED_CONV_KERNEL_SET(name) \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet name +#define BEGIN_GROUPED_CONV_KERNEL_SET() \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp new file mode 100644 index 0000000000..5b58f37206 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp @@ -0,0 +1,255 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_problem.hpp + * @brief Grouped Convolution problem definition + */ + +#pragma once + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Grouped Convolution operation type + */ +enum class GroupedConvOp +{ + Forward, // Y = Conv(X, W) + BackwardData, // dX = ConvBwdData(dY, W) + BackwardWeight // dW = ConvBwdWeight(X, dY) +}; + +/** + * @brief Grouped Convolution problem specification + */ +struct GroupedConvProblem +{ + // Batch and channels + std::int64_t N; // Batch size + std::int64_t C; // Input channels + std::int64_t K; // Output channels (filters) + std::int64_t G; // Number of groups (1 for standard conv) + + // Spatial dimensions (supports 1D, 2D, 3D) + std::array input_spatial; // {D, H, W} or {1, H, W} for 2D + std::array filter_spatial; // {Z, Y, X} or {1, Y, X} for 2D + std::array output_spatial; // {Do, Ho, Wo} or {1, Ho, Wo} for 2D + + // Convolution parameters + std::array stride; // Stride in each dimension + std::array padding; // Padding in each dimension + std::array dilation; // Dilation in each dimension + + // Operation type + GroupedConvOp op = GroupedConvOp::Forward; + + // Split-K for backward weight (k_batch parameter in CK Tile). + // Values > 1 split the reduction dimension across multiple thread blocks + // and use atomic accumulation. 
+ int split_k = 1; + + // Default constructor for 2D convolution + GroupedConvProblem() + : N(1), + C(64), + K(64), + G(1), + input_spatial{1, 28, 28}, + filter_spatial{1, 3, 3}, + output_spatial{1, 26, 26}, + stride{1, 1, 1}, + padding{0, 0, 0}, + dilation{1, 1, 1}, + op(GroupedConvOp::Forward) + { + } + + // Constructor for 2D convolution + GroupedConvProblem(std::int64_t n, + std::int64_t c, + std::int64_t k, + std::int64_t hi, + std::int64_t wi, + std::int64_t y, + std::int64_t x, + std::int64_t stride_h = 1, + std::int64_t stride_w = 1, + std::int64_t pad_h = 0, + std::int64_t pad_w = 0, + std::int64_t dilation_h = 1, + std::int64_t dilation_w = 1) + : N(n), + C(c), + K(k), + G(1), + input_spatial{1, hi, wi}, + filter_spatial{1, y, x}, + stride{1, stride_h, stride_w}, + padding{0, pad_h, pad_w}, + dilation{1, dilation_h, dilation_w}, + op(GroupedConvOp::Forward) + { + compute_output_size(); + } + + /// Check if problem dimensions are valid + bool is_valid() const + { + return N > 0 && C > 0 && K > 0 && G > 0 && (C % G == 0) && (K % G == 0); + } + + /// Compute output spatial dimensions + void compute_output_size() + { + for(int i = 0; i < 3; ++i) + { + std::int64_t effective_filter = (filter_spatial[i] - 1) * dilation[i] + 1; + output_spatial[i] = + (input_spatial[i] + 2 * padding[i] - effective_filter) / stride[i] + 1; + } + } + + /// Get 2D height/width accessors + std::int64_t Hi() const { return input_spatial[1]; } + std::int64_t Wi() const { return input_spatial[2]; } + std::int64_t Ho() const { return output_spatial[1]; } + std::int64_t Wo() const { return output_spatial[2]; } + std::int64_t Y() const { return filter_spatial[1]; } // Filter height + std::int64_t X() const { return filter_spatial[2]; } // Filter width + + /// Get total FLOPs for this convolution + double get_flops() const + { + // Forward: 2 * N * K * Ho * Wo * C * Y * X / G + double spatial_out = 1.0; + double filter_size = 1.0; + for(int i = 0; i < 3; ++i) + { + spatial_out *= 
output_spatial[i]; + filter_size *= filter_spatial[i]; + } + return 2.0 * N * K * spatial_out * (C / G) * filter_size; + } + + /// Check if this is a depthwise convolution + bool is_depthwise() const { return G == C && G == K; } + + /// Check if this is a pointwise (1x1) convolution + bool is_pointwise() const + { + return filter_spatial[0] == 1 && filter_spatial[1] == 1 && filter_spatial[2] == 1; + } + + /// String representation + std::string to_string() const + { + std::string s = "GroupedConvProblem(N=" + std::to_string(N); + s += ", C=" + std::to_string(C) + ", K=" + std::to_string(K); + s += ", G=" + std::to_string(G); + s += ", Hi=" + std::to_string(Hi()) + ", Wi=" + std::to_string(Wi()); + s += ", Y=" + std::to_string(Y()) + ", X=" + std::to_string(X()); + s += ", Ho=" + std::to_string(Ho()) + ", Wo=" + std::to_string(Wo()); + s += ")"; + return s; + } +}; + +// ============================================================================= +// GroupedConvProblemBuilder +// ============================================================================= + +/// Builder pattern for Grouped Convolution problem configuration +class GroupedConvProblemBuilder +{ + public: + GroupedConvProblemBuilder() = default; + + GroupedConvProblemBuilder& batch(std::int64_t n) + { + problem_.N = n; + return *this; + } + + GroupedConvProblemBuilder& channels(std::int64_t c, std::int64_t k) + { + problem_.C = c; + problem_.K = k; + return *this; + } + + GroupedConvProblemBuilder& groups(std::int64_t g) + { + problem_.G = g; + return *this; + } + + GroupedConvProblemBuilder& input_size(std::int64_t h, std::int64_t w) + { + problem_.input_spatial[0] = 1; + problem_.input_spatial[1] = h; + problem_.input_spatial[2] = w; + return *this; + } + + GroupedConvProblemBuilder& filter_size(std::int64_t y, std::int64_t x) + { + problem_.filter_spatial[0] = 1; + problem_.filter_spatial[1] = y; + problem_.filter_spatial[2] = x; + return *this; + } + + GroupedConvProblemBuilder& 
stride(std::int64_t sh, std::int64_t sw) + { + problem_.stride[0] = 1; + problem_.stride[1] = sh; + problem_.stride[2] = sw; + return *this; + } + + GroupedConvProblemBuilder& padding(std::int64_t ph, std::int64_t pw) + { + problem_.padding[0] = 0; + problem_.padding[1] = ph; + problem_.padding[2] = pw; + return *this; + } + + GroupedConvProblemBuilder& dilation(std::int64_t dh, std::int64_t dw) + { + problem_.dilation[0] = 1; + problem_.dilation[1] = dh; + problem_.dilation[2] = dw; + return *this; + } + + GroupedConvProblemBuilder& operation(GroupedConvOp op) + { + problem_.op = op; + return *this; + } + + [[nodiscard]] GroupedConvProblem build() const + { + GroupedConvProblem p = problem_; + p.compute_output_size(); + if(!p.is_valid()) + { + throw std::invalid_argument("Invalid grouped convolution problem dimensions"); + } + return p; + } + + private: + GroupedConvProblem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp new file mode 100644 index 0000000000..42698a0bc8 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp @@ -0,0 +1,614 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_registry.hpp + * @brief Grouped Convolution kernel registry and dispatcher + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Thread-local buffer context for GroupedConvDispatcher::run() +// The generated conv backend RunFn reads these to get buffer pointers. +// ============================================================================= + +struct ConvDispatchBuffers +{ + const void* input_ptr = nullptr; + const void* weight_ptr = nullptr; + void* output_ptr = nullptr; + int warmup = 3; + int repeat = 10; + bool benchmarking = true; + int split_k = 1; +}; + +inline thread_local ConvDispatchBuffers g_conv_dispatch_buffers; + +// ============================================================================= +// GroupedConvKernelKey - Unique identifier for a grouped convolution kernel +// ============================================================================= + +struct GroupedConvKernelKey +{ + // Signature fields + std::string dtype_in; + std::string dtype_wei; + std::string dtype_out; + std::string layout; // e.g., "nhwgc" + int ndim_spatial = 2; // 1, 2, or 3 + GroupedConvOp op = GroupedConvOp::Forward; + + // Tile configuration + int tile_m = 1; + int tile_n = 128; + int tile_k = 128; + + // Wave/warp configuration + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Pipeline + std::string pipeline = "compv3"; + std::string scheduler = "intrawave"; + std::string epilogue = "cshuffle"; + + // 
ConvConfigBase parity fields + int vector_size_a = 4; + int vector_size_b = 8; + int vector_size_c = 8; + int block_per_cu = 1; + int num_wave_groups = 1; + int num_groups_to_merge = 1; + + // GPU architecture (for filter_by_arch) + std::string arch = "gfx942"; + + bool operator==(const GroupedConvKernelKey& other) const + { + return dtype_in == other.dtype_in && dtype_wei == other.dtype_wei && + dtype_out == other.dtype_out && layout == other.layout && + ndim_spatial == other.ndim_spatial && op == other.op && tile_m == other.tile_m && + tile_n == other.tile_n && tile_k == other.tile_k && wave_m == other.wave_m && + wave_n == other.wave_n && wave_k == other.wave_k && warp_m == other.warp_m && + warp_n == other.warp_n && warp_k == other.warp_k && pipeline == other.pipeline && + scheduler == other.scheduler && epilogue == other.epilogue && + vector_size_a == other.vector_size_a && vector_size_b == other.vector_size_b && + vector_size_c == other.vector_size_c && block_per_cu == other.block_per_cu && + num_wave_groups == other.num_wave_groups && + num_groups_to_merge == other.num_groups_to_merge && arch == other.arch; + } + + std::string to_string() const + { + std::string op_str; + switch(op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + return "grouped_conv_" + op_str + "_" + dtype_in + "_" + std::to_string(ndim_spatial) + + "d_" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "x" + + std::to_string(tile_k) + "_" + std::to_string(wave_m) + "x" + + std::to_string(wave_n) + "x" + std::to_string(wave_k) + "_" + + std::to_string(warp_m) + "x" + std::to_string(warp_n) + "x" + + std::to_string(warp_k) + "_" + pipeline; + } +}; + +struct GroupedConvKernelKeyHash +{ + std::size_t operator()(const GroupedConvKernelKey& key) const + { + std::size_t h = std::hash{}(key.dtype_in); + h ^= std::hash{}(key.layout) << 1; + h 
^= std::hash{}(key.ndim_spatial) << 2; + h ^= std::hash{}(static_cast(key.op)) << 3; + h ^= std::hash{}(key.tile_m) << 4; + h ^= std::hash{}(key.tile_n) << 5; + h ^= std::hash{}(key.tile_k) << 6; + h ^= std::hash{}(key.wave_m) << 7; + h ^= std::hash{}(key.wave_n) << 8; + h ^= std::hash{}(key.warp_m) << 9; + h ^= std::hash{}(key.warp_n) << 10; + h ^= std::hash{}(key.pipeline) << 11; + h ^= std::hash{}(key.arch) << 12; + return h; + } +}; + +// ============================================================================= +// GroupedConvKernelInstance - Runtime representation of a kernel +// ============================================================================= + +// Forward declaration for shared_ptr type alias +class GroupedConvKernelInstance; +using GroupedConvKernelInstancePtr = std::shared_ptr; + +class GroupedConvKernelInstance +{ + public: + using RunFn = std::function; + + GroupedConvKernelInstance(const GroupedConvKernelKey& key, + const std::string& name, + RunFn run_fn) + : key_(key), name_(name), run_fn_(std::move(run_fn)) + { + } + + const GroupedConvKernelKey& key() const { return key_; } + const std::string& name() const { return name_; } + + float run(const GroupedConvProblem& problem, void* stream = nullptr) const + { + return run_fn_(problem, stream); + } + + bool matches(const GroupedConvProblem& problem) const + { + // Check if this kernel can handle the problem + return problem.op == key_.op; + } + + private: + GroupedConvKernelKey key_; + std::string name_; + RunFn run_fn_; +}; + +// ============================================================================= +// GroupedConvRegistry - Stores and manages grouped convolution kernels +// ============================================================================= + +class GroupedConvRegistry : public BaseRegistry +{ + using Base = BaseRegistry; + + public: + GroupedConvRegistry() = default; + + /// Singleton instance for global kernel registration + static GroupedConvRegistry& instance() + 
{ + static GroupedConvRegistry registry; + return registry; + } + + /// Register kernels from a GroupedConvKernelSet (atomic batch registration) + bool register_set(const GroupedConvKernelSet& kernel_set, Priority priority = Priority::Normal) + { + // Build all instances first, then register under a single lock hold + // so readers never see a half-registered set. + std::vector>> + batch; + batch.reserve(kernel_set.declarations().size()); + + for(const auto& decl : kernel_set.declarations()) + { + GroupedConvKernelKey key; + key.dtype_in = decl.signature.dtype_in_; + key.dtype_wei = decl.signature.dtype_wei_; + key.dtype_out = decl.signature.dtype_out_; + key.layout = decl.signature.layout_; + key.ndim_spatial = decl.signature.num_dims_; + key.op = (decl.signature.conv_op_ == "forward") ? GroupedConvOp::Forward + : (decl.signature.conv_op_ == "bwd_data") ? GroupedConvOp::BackwardData + : GroupedConvOp::BackwardWeight; + key.tile_m = decl.algorithm.tile_m_; + key.tile_n = decl.algorithm.tile_n_; + key.tile_k = decl.algorithm.tile_k_; + key.wave_m = decl.algorithm.wave_m_; + key.wave_n = decl.algorithm.wave_n_; + key.wave_k = decl.algorithm.wave_k_; + key.warp_m = decl.algorithm.warp_m_; + key.warp_n = decl.algorithm.warp_n_; + key.warp_k = decl.algorithm.warp_k_; + key.pipeline = decl.algorithm.pipeline_; + key.scheduler = decl.algorithm.scheduler_; + key.epilogue = decl.algorithm.epilogue_; + key.vector_size_a = decl.algorithm.vector_a_; + key.vector_size_b = decl.algorithm.vector_b_; + key.vector_size_c = decl.algorithm.vector_c_; + key.block_per_cu = decl.algorithm.block_per_cu_; + key.num_wave_groups = decl.algorithm.num_wave_groups_; + key.num_groups_to_merge = decl.algorithm.num_groups_to_merge_; + key.arch = decl.arch; + + batch.emplace_back(key, + std::make_shared( + key, decl.name(), [](const GroupedConvProblem&, void*) -> float { + return 0.0f; + })); + } + + std::lock_guard lock(mutex()); + bool any_registered = false; + for(auto& [key, instance] : batch) 
+ { + auto it = entries().find(key); + if(it == entries().end() || it->second.priority <= priority) + { + entries_mut()[key] = typename Base::Entry{std::move(instance), priority}; + any_registered = true; + } + } + return any_registered; + } + + /// Find the best kernel for a problem + const GroupedConvKernelInstance* find(const GroupedConvProblem& problem) const + { + std::lock_guard lock(mutex()); + const GroupedConvKernelInstance* best = nullptr; + Priority best_priority = Priority::Low; + + for(const auto& [key, entry] : entries()) + { + if(entry.instance->matches(problem)) + { + if(!best || entry.priority > best_priority) + { + best = entry.instance.get(); + best_priority = entry.priority; + } + } + } + + return best; + } + + /// Get all registered kernels + std::vector all_kernels() const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + result.push_back(entry.instance.get()); + } + return result; + } + + /// Export registry to JSON string + std::string export_json(bool include_statistics = false) const + { + // Note: get_name() acquires the mutex internally, so we must NOT hold + // the registry mutex here (std::mutex is not recursive). 
+ std::string reg_name = get_name(); + + std::lock_guard lock(mutex()); + std::ostringstream json; + + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"registry_name\": \"" << json_escape(reg_name) << "\",\n"; + json << " \"total_kernels\": " << entries().size() << "\n"; + json << " }"; + + if(include_statistics && !entries().empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_arch; + + for(const auto& [key, entry] : entries()) + { + std::string dtype_key = key.dtype_in + "_" + key.dtype_wei + "_" + key.dtype_out; + by_datatype[dtype_key]++; + by_pipeline[key.pipeline]++; + by_arch[key.arch]++; + } + + json << ",\n \"statistics\": {\n"; + json << " \"by_datatype\": {"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ","; + json << "\"" << json_escape(dtype) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_pipeline\": {"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ","; + json << "\"" << json_escape(pipeline) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_arch\": {"; + first = true; + for(const auto& [arch, count] : by_arch) + { + if(!first) + json << ","; + json << "\"" << json_escape(arch) << "\":" << count; + first = false; + } + json << "}\n }"; + } + + json << ",\n \"kernels\": [\n"; + bool first = true; + for(const auto& [key, entry] : entries()) + { + if(!first) + json << ",\n"; + json << " " << export_kernel_json(*entry.instance); + first = false; + } + json << "\n ]\n"; + json << "}\n"; + + return json.str(); + } + + /// Export registry to JSON file + void export_json_to_file(const std::string& filename, bool include_statistics = false) const + { + std::string json_str = export_json(include_statistics); + std::ofstream file(filename); + if(!file.is_open()) + { + throw std::runtime_error("Failed to open file for export: " + filename); + } + file << json_str; + } + + /// Get 
kernels matching a predicate + std::vector + filter(std::function predicate) const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + if(predicate(*entry.instance)) + { + result.push_back(entry.instance.get()); + } + } + return result; + } + + /// Remove kernels not matching the arch + std::size_t filter_by_arch(const std::string& gpu_arch) + { + std::lock_guard lock(mutex()); + std::vector to_remove; + for(const auto& [key, entry] : entries()) + { + if(key.arch != gpu_arch) + { + to_remove.push_back(key); + } + } + for(const auto& key : to_remove) + { + entries_mut().erase(key); + } + return to_remove.size(); + } + + private: + static std::string json_escape(const std::string& str) + { + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); + } + + static std::string export_kernel_json(const GroupedConvKernelInstance& kernel) + { + std::ostringstream json; + const auto& key = kernel.key(); + + std::string op_str; + switch(key.op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + + json << "{\n"; + json << " \"name\": \"" << json_escape(kernel.name()) << "\",\n"; + json << " \"signature\": {\n"; + json << " \"dtype_in\": \"" << json_escape(key.dtype_in) << "\",\n"; + json << " \"dtype_wei\": \"" << json_escape(key.dtype_wei) << "\",\n"; + json << " \"dtype_out\": \"" << json_escape(key.dtype_out) << "\",\n"; + json << " \"layout\": \"" << 
json_escape(key.layout) << "\",\n"; + json << " \"ndim_spatial\": " << key.ndim_spatial << ",\n"; + json << " \"op\": \"" << op_str << "\"\n"; + json << " },\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_m\": " << key.tile_m << ",\n"; + json << " \"tile_n\": " << key.tile_n << ",\n"; + json << " \"tile_k\": " << key.tile_k << ",\n"; + json << " \"wave\": \"" << key.wave_m << "x" << key.wave_n << "x" << key.wave_k + << "\",\n"; + json << " \"warp\": \"" << key.warp_m << "x" << key.warp_n << "x" << key.warp_k + << "\",\n"; + json << " \"pipeline\": \"" << json_escape(key.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << json_escape(key.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << json_escape(key.epilogue) << "\",\n"; + json << " \"vector_sizes\": [" << key.vector_size_a << "," << key.vector_size_b + << "," << key.vector_size_c << "],\n"; + json << " \"block_per_cu\": " << key.block_per_cu << ",\n"; + json << " \"num_wave_groups\": " << key.num_wave_groups << ",\n"; + json << " \"num_groups_to_merge\": " << key.num_groups_to_merge << "\n"; + json << " },\n"; + json << " \"arch\": \"" << json_escape(key.arch) << "\"\n"; + json << " }"; + + return json.str(); + } +}; + +// ============================================================================= +// GroupedConvDispatcher - Selects and runs the best kernel for a problem +// ============================================================================= + +class GroupedConvDispatcher +{ + public: + enum class SelectionStrategy + { + PriorityBased, + Heuristic + }; + + using HeuristicFunction = std::function(const GroupedConvProblem&)>; + + explicit GroupedConvDispatcher(GroupedConvRegistry* registry) + : registry_(registry), strategy_(SelectionStrategy::PriorityBased) + { + } + + void set_strategy(SelectionStrategy s) { strategy_ = s; } + void set_heuristic(HeuristicFunction fn) { heuristic_ = std::move(fn); } + + /// Select the best kernel for a problem (does not run it) + const 
GroupedConvKernelInstance* select_kernel(const GroupedConvProblem& problem) const + { + if(strategy_ == SelectionStrategy::Heuristic) + return select_heuristic(problem); + return registry_->find(problem); + } + + /// Run convolution with automatic kernel selection (legacy - no buffers) + float run(const GroupedConvProblem& problem, void* stream = nullptr) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + return kernel->run(problem, stream); + } + + /// Run convolution with buffer pointers and automatic kernel selection. + /// Sets the thread-local buffer context before dispatching to the kernel. + float run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const GroupedConvProblem& problem, + void* stream = nullptr, + int warmup = 3, + int repeat = 10) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + g_conv_dispatch_buffers.input_ptr = input_ptr; + g_conv_dispatch_buffers.weight_ptr = weight_ptr; + g_conv_dispatch_buffers.output_ptr = output_ptr; + g_conv_dispatch_buffers.warmup = warmup; + g_conv_dispatch_buffers.repeat = repeat; + g_conv_dispatch_buffers.benchmarking = benchmarking_; + g_conv_dispatch_buffers.split_k = problem.split_k; + return kernel->run(problem, stream); + } + + /// Enable or disable GPU benchmarking (timing). + /// When disabled, kernels execute once with no timing overhead. 
+ void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + + /// Alias kept for backward compatibility + const GroupedConvKernelInstance* select(const GroupedConvProblem& problem) const + { + return select_kernel(problem); + } + + private: + const GroupedConvKernelInstance* select_heuristic(const GroupedConvProblem& problem) const + { + if(!heuristic_) + return registry_->find(problem); + + auto ranked_names = heuristic_(problem); + auto all = registry_->all_kernels(); + for(const auto& name : ranked_names) + { + for(const auto* kernel : all) + { + if(kernel->name().find(name) != std::string::npos && kernel->matches(problem)) + { + return kernel; + } + } + } + return registry_->find(problem); + } + + GroupedConvRegistry* registry_; + SelectionStrategy strategy_; + HeuristicFunction heuristic_; + bool benchmarking_ = true; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp new file mode 100644 index 0000000000..c817d36673 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_utils.hpp + * @brief CK Tile Grouped Convolution Dispatcher Utilities + */ + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +namespace grouped_conv_utils { + +inline GroupedConvKernelDecl create_grouped_conv2d_fwd(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv3d_fwd(const std::string& dtype = "fp16", + int tile_n = 64, + int tile_k = 64, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("ndhwc").conv_type("forward").dims(3), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_data(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 
16) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_weight(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvProblem create_grouped_conv2d_problem(int N, + int C, + int K, + int Hi, + int Wi, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline GroupedConvProblem create_grouped_conv3d_problem(int N, + int C, + int K, + int Di, + int Hi, + int Wi, + int Z, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {Di, Hi, Wi}; + p.filter_spatial = {Z, Y, X}; + p.stride = {stride, stride, stride}; + p.padding = {padding, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline GroupedConvProblem create_depthwise_grouped_conv2d_problem( + int N, int C, int Hi, int Wi, int Y, int X, int stride = 1, int padding = 0) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = C; + p.G = C; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = GroupedConvOp::Forward; + p.compute_output_size(); + return p; +} 
+ +inline void print_pattern_docs(std::ostream& os = std::cout) +{ + os << "Grouped Convolution Pattern Documentation\n"; + os << "==========================================\n"; + os << "Signature patterns: dtype, layout, conv_type (forward/bwd_data/bwd_weight), dims " + "(2/3)\n"; + os << "Algorithm patterns: tile(M,N,K), wave(M,N,K), warp(M,N,K), pipeline, vector_sizes\n"; + os << "Arch patterns: gfx942, gfx90a, gfx950, or '*' for all\n"; +} + +inline void print_grouped_conv_kernel_decl(const GroupedConvKernelDecl& decl, + std::ostream& os = std::cout) +{ + os << "GroupedConvKernelDecl: " << decl.name() << "\n"; + os << " Signature: dtype=" << decl.signature.dtype_in_ << ", layout=" << decl.signature.layout_ + << ", conv_type=" << decl.signature.conv_op_ << ", dims=" << decl.signature.num_dims_ + << "\n"; + os << " Algorithm: tile=" << decl.algorithm.tile_m_ << "x" << decl.algorithm.tile_n_ << "x" + << decl.algorithm.tile_k_ << ", wave=" << decl.algorithm.wave_m_ << "x" + << decl.algorithm.wave_n_ << "x" << decl.algorithm.wave_k_ + << ", warp=" << decl.algorithm.warp_m_ << "x" << decl.algorithm.warp_n_ << "x" + << decl.algorithm.warp_k_ << ", pipeline=" << decl.algorithm.pipeline_ << "\n"; + os << " Arch: " << decl.arch << "\n"; +} + +inline void print_grouped_conv_problem(const GroupedConvProblem& p, std::ostream& os = std::cout) +{ + os << p.to_string() << "\n"; + os << " FLOPs: " << std::scientific << p.get_flops() << "\n"; +} + +inline GroupedConvKernelSet build_grouped_conv2d_fwd_set(const std::string& dtype = "fp16", + const std::string& arch = "gfx942") +{ + GroupedConvKernelSet set; + auto decl1 = create_grouped_conv2d_fwd(dtype, 128, 128, arch); + set.add(decl1.signature, decl1.algorithm, decl1.arch); + auto decl2 = create_grouped_conv2d_fwd(dtype, 256, 256, arch); + set.add(decl2.signature, decl2.algorithm, decl2.arch); + return set; +} + +inline GroupedConvKernelSet build_grouped_conv2d_full_set(const std::string& dtype = "fp16", + const std::string& 
arch = "gfx942") +{ + GroupedConvKernelSet set; + set.merge(build_grouped_conv2d_fwd_set(dtype, arch)); + auto bwd_data = create_grouped_conv2d_bwd_data(dtype, 128, 128, arch); + set.add(bwd_data.signature, bwd_data.algorithm, bwd_data.arch); + auto bwd_weight = create_grouped_conv2d_bwd_weight(dtype, 128, 128, arch); + set.add(bwd_weight.signature, bwd_weight.algorithm, bwd_weight.arch); + return set; +} + +struct ValidationResult +{ + bool passed = false; + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + float rtol = 1e-3f; + float atol = 1e-3f; + + void print(std::ostream& os = std::cout) const + { + os << "ValidationResult: " << (passed ? "PASSED" : "FAILED") << "\n"; + os << " max_abs_diff: " << max_abs_diff << ", max_rel_diff: " << max_rel_diff << "\n"; + os << " rtol: " << rtol << ", atol: " << atol << "\n"; + } +}; + +template +inline ValidationResult validate_buffers( + const T* result, const T* reference, size_t count, float rtol = 1e-3f, float atol = 1e-3f) +{ + ValidationResult vr; + vr.rtol = rtol; + vr.atol = atol; + vr.passed = true; + + for(size_t i = 0; i < count; ++i) + { + float r = static_cast(result[i]); + float ref = static_cast(reference[i]); + float abs_diff = std::abs(r - ref); + float rel_diff = (std::abs(ref) > 1e-10f) ? 
(abs_diff / std::abs(ref)) : 0.0f; + + vr.max_abs_diff = std::max(vr.max_abs_diff, abs_diff); + vr.max_rel_diff = std::max(vr.max_rel_diff, rel_diff); + + float threshold = atol + rtol * std::abs(ref); + if(abs_diff > threshold) + { + vr.passed = false; + } + } + + return vr; +} + +struct BenchmarkResult +{ + std::string kernel_name; + float time_ms = 0.0f; + float tflops = 0.0f; + int warmup_runs = 0; + int benchmark_runs = 0; + + void print(std::ostream& os = std::cout) const + { + os << "BenchmarkResult: " << kernel_name << "\n"; + os << " time_ms: " << time_ms << ", tflops: " << tflops << "\n"; + os << " warmup_runs: " << warmup_runs << ", benchmark_runs: " << benchmark_runs << "\n"; + } +}; + +inline float calc_tflops(double flops, float time_ms) +{ + return static_cast(flops / (time_ms * 1e9)); +} + +inline double calculate_conv_tflops(const GroupedConvProblem& problem, double time_ms) +{ + return problem.get_flops() / (time_ms * 1e9); +} + +} // namespace grouped_conv_utils + +namespace examples { +inline int basic_grouped_conv_example_main(const std::string& example_name) +{ + std::cout << "=== " << example_name << " ===\n"; + + // Create a grouped convolution problem + auto problem = grouped_conv_utils::create_grouped_conv2d_problem( + 32, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + grouped_conv_utils::print_grouped_conv_problem(problem); + + // Create and print a kernel declaration + auto decl = grouped_conv_utils::create_grouped_conv2d_fwd("fp16", 128, 128, "gfx942"); + grouped_conv_utils::print_grouped_conv_kernel_decl(decl); + + // Build and print kernel set + auto kernel_set = grouped_conv_utils::build_grouped_conv2d_fwd_set("fp16", "gfx942"); + kernel_set.print(); + + return 0; +} +} // namespace examples + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp index 437511d1ba..5bffb56b49 100644 --- 
a/dispatcher/include/ck_tile/dispatcher/problem.hpp +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -98,7 +98,7 @@ struct Problem /** * Create Problem by inferring MNK from tensor shapes. * - * For GEMM: C[M,N] = A[M,K] × B[K,N] + * For GEMM: C[M,N] = A[M,K] x B[K,N] * * @param a_shape Shape of matrix A (M x K, or K x M if transposed) * @param b_shape Shape of matrix B (K x N, or N x K if transposed) @@ -113,7 +113,7 @@ struct Problem [[nodiscard]] static Problem from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) { - // For C = A × B: + // For C = A x B: // A: [M, K] (or [K, M] if transposed) // B: [K, N] (or [N, K] if transposed) // C: [M, N] @@ -164,7 +164,7 @@ struct Problem * @throws std::invalid_argument if dimensions are inconsistent * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); */ [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, @@ -188,7 +188,7 @@ struct Problem * @throws std::invalid_argument if K dimensions don't match * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_ab(512, 256, 256, 1024); */ [[nodiscard]] static Problem diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp index 93d1eb9f64..4f34e589ea 100644 --- a/dispatcher/include/ck_tile/dispatcher/registry.hpp +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -7,38 +7,20 @@ * Central registry for all available kernel instances with priority-based * ordering and efficient lookup. 
* - * Features: - * - Thread-safe registration and lookup - * - Priority-based ordering (High, Normal, Low) - * - Lookup by name or KernelKey - * - Filter by problem compatibility - * - Supports both singleton and multiple instance patterns - * - * Usage (Singleton - backward compatible): - * auto& registry = Registry::instance(); - * registry.register_kernel(kernel, Priority::High); - * auto kernel = registry.lookup("kernel_name"); - * - * Usage (Multiple registries): - * Registry fp16_registry; - * Registry bf16_registry; - * fp16_registry.register_kernel(fp16_kernel, Priority::High); - * bf16_registry.register_kernel(bf16_kernel, Priority::High); - * - * Dispatcher fp16_dispatcher(&fp16_registry); - * Dispatcher bf16_dispatcher(&bf16_registry); + * Derives from BaseRegistry for shared logic (thread safety, naming, priority, + * merge) while keeping GEMM-specific APIs (lookup by KernelKey, filter_by_arch, + * JSON export, auto-export). * * Status: Production ready, thread-safe */ #pragma once +#include "ck_tile/dispatcher/base_registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" #include -#include #include -#include #include #include @@ -47,20 +29,16 @@ namespace dispatcher { /// Registry: Central mapping from kernel configurations to executable instances /// Thread-safe kernel registration and lookup -/// Supports both singleton pattern and multiple independent instances -class Registry +/// Derives from BaseRegistry for shared functionality +class Registry : public BaseRegistry { + using Base = BaseRegistry; + public: - /// Priority levels for conflict resolution when multiple kernels have same key - enum class Priority - { - Low = 0, - Normal = 1, - High = 2 - }; + // Re-export Priority from the shared enum for backward compatibility + using Priority = ck_tile::dispatcher::Priority; /// Default constructor - creates an empty registry instance - /// Use this to create independent registries for different 
kernel sets Registry(); /// Destructor - triggers auto-export if enabled @@ -72,106 +50,51 @@ class Registry /// Move assignment Registry& operator=(Registry&& other) noexcept; - // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + // Prevent copying Registry(const Registry&) = delete; Registry& operator=(const Registry&) = delete; /// Register a kernel instance with the registry - /// @param instance Kernel instance to register - /// @param priority Priority level for conflict resolution (default: Normal) - /// @return true if registered successfully, false if duplicate with higher priority exists bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); /// Lookup a kernel by its string identifier - /// @param identifier Kernel identifier string - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; /// Lookup a kernel by its KernelKey - /// @param key Kernel configuration key - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; /// Get all registered kernels - /// @return Vector of all kernel instances [[nodiscard]] std::vector get_all() const; /// Get all kernels matching a predicate - /// @param predicate Function to filter kernels - /// @return Vector of matching kernel instances [[nodiscard]] std::vector filter(std::function predicate) const; - /// Get number of registered kernels - [[nodiscard]] std::size_t size() const; - - /// Check if registry is empty - [[nodiscard]] bool empty() const; - - /// Clear all registered kernels - void clear(); - - /// Get registry name (for logging/debugging) - [[nodiscard]] const std::string& get_name() const; - - /// Set registry name (for logging/debugging) - void set_name(const std::string& name); + // size(), empty(), clear(), get_name(), set_name(), merge_from() inherited from Base /// Export registry to JSON 
string - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return JSON string with all kernel metadata [[nodiscard]] std::string export_json(bool include_statistics = true) const; /// Export registry to JSON file - /// @param filename Output filename - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return true if export succeeded, false otherwise bool export_json_to_file(const std::string& filename, bool include_statistics = true) const; - /// Enable automatic JSON export on kernel registration - /// @param filename Output filename for auto-export - /// @param include_statistics Whether to include statistics in auto-export - /// @param export_on_every_registration If true, exports after every registration (default). - /// If false, only exports on destruction. void enable_auto_export(const std::string& filename, bool include_statistics = true, bool export_on_every_registration = true); - /// Disable automatic JSON export void disable_auto_export(); - /// Check if auto-export is enabled [[nodiscard]] bool is_auto_export_enabled() const; - /// Merge kernels from another registry into this one - /// @param other Registry to merge from - /// @param priority Priority for merged kernels (default: Normal) - /// @return Number of kernels successfully merged - std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); - /// Filter kernels in-place by architecture - /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") - /// @return Number of kernels removed std::size_t filter_by_arch(const std::string& gpu_arch); - /// Get singleton instance of the global registry (backward compatible) - /// This is the default registry used when no specific registry is provided + /// Get singleton instance static Registry& instance(); private: - struct RegistryEntry - { - KernelInstancePtr instance; - Priority priority; - }; - - /// Perform auto-export if enabled void 
perform_auto_export(); - mutable std::mutex mutex_; - std::unordered_map kernels_; - std::string name_; - // Auto-export configuration bool auto_export_enabled_ = false; std::string auto_export_filename_; @@ -179,7 +102,7 @@ class Registry bool auto_export_on_every_registration_ = true; }; -/// Shared pointer type for registries (useful for managing lifetime) +/// Shared pointer type for registries using RegistryPtr = std::shared_ptr; /// Create a new registry instance (factory function) diff --git a/dispatcher/include/ck_tile/dispatcher_conv.hpp b/dispatcher/include/ck_tile/dispatcher_conv.hpp new file mode 100644 index 0000000000..46d14f90f3 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_conv.hpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Grouped Convolution-only dispatcher header -- minimal include for conv operations. + +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher_gemm.hpp b/dispatcher/include/ck_tile/dispatcher_gemm.hpp new file mode 100644 index 0000000000..79317c7399 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_gemm.hpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// GEMM-only dispatcher header -- minimal include for GEMM operations. 
+ +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt index e57678952e..71634fa926 100644 --- a/dispatcher/python/CMakeLists.txt +++ b/dispatcher/python/CMakeLists.txt @@ -3,7 +3,7 @@ # This directory contains Python utilities for the dispatcher examples. # The main utility file is ctypes_utils.py which is used by GEMM Python examples. -# Conv Python examples use their own conv_utils.py in the examples directory. +# Grouped conv Python examples use grouped_conv_utils.py in this directory. # No build targets needed - these are pure Python utilities. message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md index 9286acbf72..edbc7acc9d 100644 --- a/dispatcher/python/README.md +++ b/dispatcher/python/README.md @@ -4,6 +4,19 @@ This directory contains Python utilities used by the dispatcher examples. ## Contents +### Shared Utilities (used by both GEMM and Grouped Conv) + +- `dispatcher_common.py` - Shared dispatcher infrastructure + - Path helpers (`get_dispatcher_root`, `get_build_dir`, etc.) 
+ - `ValidationResultBase` - Structured validation feedback + - `validate_wave_config`, `validate_warp_tile_config`, `validate_trait_combo` + - `auto_correct_wave`, `auto_correct_trait` - Auto-correction helpers + - `Colors` - Cross-platform ANSI color support + - `print_phase`, `print_success`, `print_error`, `print_info` - Phased output + - `cleanup_generated_kernels` - Cleanup helper + +### GEMM Utilities + - `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples - `KernelConfig` - Kernel configuration dataclass - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction @@ -11,11 +24,15 @@ This directory contains Python utilities used by the dispatcher examples. - `GemmRunner` - GPU execution helper - Auto-correction and validation utilities -- `conv_utils.py` - Core utilities for Conv Python examples - - `ConvSignature`, `ConvAlgorithm` - Convolution configuration - - `ConvProblem` - Problem definition - - `GpuConvRunner` - GPU execution helper - - `EnhancedConvCodegenRunner` - Kernel codegen utilities +### Grouped Convolution Utilities + +- `grouped_conv_utils.py` - Utilities for grouped convolution + - `GroupedConvValidationResult` - Validation result (extends `ValidationResultBase`) + - `validate_grouped_conv_config` - Validate a grouped conv config + - `auto_correct_grouped_conv_config` - Auto-correct invalid configs + - `get_grouped_conv_default_config` - Get default config for a variant + - `GroupedConvDataType` - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8) + - `format_grouped_conv_summary` - Human-readable config summary ## Usage @@ -36,21 +53,26 @@ from ctypes_utils import ( ) ``` -### Conv Examples - -The Conv Python examples in `dispatcher/examples/conv/python/` import: +### Grouped Conv Usage ```python import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) -from conv_utils import ( - ConvSignature, - ConvAlgorithm, - ConvProblem, - GpuConvRunner, +from 
grouped_conv_utils import ( + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, ) + +# Get a default config +config = get_grouped_conv_default_config(variant="forward", arch="gfx942") + +# Validate +result = validate_grouped_conv_config(config) +print(f"Valid: {result.is_valid}") ``` ## Requirements diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py index 821fc2b08d..c11aaca835 100644 --- a/dispatcher/python/ctypes_utils.py +++ b/dispatcher/python/ctypes_utils.py @@ -37,6 +37,43 @@ import multiprocessing import time +# ============================================================================= +# GPU Architecture Auto-Detection +# ============================================================================= + +_detected_arch: Optional[str] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """ + Auto-detect the GPU architecture by querying rocminfo. + + Caches the result after the first call. Falls back to `fallback` if + detection fails (e.g. no GPU, rocminfo not installed). + """ + global _detected_arch + if _detected_arch is not None: + return _detected_arch + + try: + result = subprocess.run( + ["/opt/rocm/bin/rocminfo"], capture_output=True, text=True, timeout=10 + ) + for line in result.stdout.splitlines(): + stripped = line.strip() + if stripped.startswith("Name:") and "gfx" in stripped: + # Extract e.g. 
"gfx950" from "Name: gfx950" + name = stripped.split(":", 1)[1].strip() + if name.startswith("gfx") and name[3:].isdigit(): + _detected_arch = name + return _detected_arch + except Exception: + pass + + _detected_arch = fallback + return _detected_arch + + # ============================================================================= # Path Configuration # ============================================================================= @@ -159,9 +196,9 @@ class ValidationResult: def print_result(self, indent: str = " "): """Print validation result.""" if self.is_valid: - print(f"{indent}✓ Configuration valid") + print(f"{indent}OK Configuration valid") else: - print(f"{indent}⚠ Configuration has issues:") + print(f"{indent}WARNING Configuration has issues:") for err in self.errors: print(f"{indent} - {err}") @@ -300,7 +337,7 @@ def auto_correct_kernel_config( # Check each fix and describe what changed if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: corrections.append( - f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"Scheduler: {config.scheduler} -> {fixes['scheduler']} " f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" ) @@ -309,7 +346,7 @@ def auto_correct_kernel_config( new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" if old_wave != new_wave: corrections.append( - f"Wave config: {old_wave} → {new_wave} " + f"Wave config: {old_wave} -> {new_wave} " f"(original not supported on {config.gfx_arch})" ) @@ -318,7 +355,7 @@ def auto_correct_kernel_config( new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" if old_warp != new_warp: corrections.append( - f"Warp tile: {old_warp} → {new_warp} " + f"Warp tile: {old_warp} -> {new_warp} " f"(original not supported for {config.dtype_a} on {config.gfx_arch})" ) @@ -386,13 +423,13 @@ def 
print_auto_correction( indent: Indentation for output """ if not corrections: - print(f"{indent}✓ Configuration valid - no corrections needed") + print(f"{indent}OK Configuration valid - no corrections needed") return - print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"\n{indent}WARNING AUTO-CORRECTION APPLIED:") print(f"{indent}" + "-" * 50) for correction in corrections: - print(f"{indent} • {correction}") + print(f"{indent} - {correction}") print(f"{indent}" + "-" * 50) print() @@ -976,6 +1013,226 @@ def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: ) +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Module-level function to run hipcc compilation in parallel.""" + import subprocess + from pathlib import Path + + compile_cmd = args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:200]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:200]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, str(e) + + +def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Module-level function: generate ONE kernel .hpp via --config JSON file. + + Used by setup_multiple_gemm_dispatchers for per-config parallel codegen. + Returns (success, header_path_or_None, error_msg). 
+ """ + import subprocess + import json + import tempfile + import os + from pathlib import Path + + try: + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Write the single-config JSON to a temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(args["tile_config_json"], f) + config_file = f.name + + cmd = [ + args["python"], + str(args["codegen_script"]), + "--output-dir", + str(out_dir), + "--datatype", + args["dtype"], + "--layout", + args["layout"], + "--gpu-target", + args["gpu_target"], + "--config", + config_file, + "--variants", + "standard", + ] + + res = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + os.unlink(config_file) + + if res.returncode != 0: + return False, None, f"Codegen failed: {res.stderr[:200]}" + + # Find the generated .hpp using the expected name pattern + pattern = args["hpp_glob_pattern"] + matches = sorted(out_dir.glob(pattern)) + if matches: + return True, str(matches[0]), "" + else: + return False, None, f"No .hpp matching {pattern} after codegen" + + except Exception as e: + return False, None, str(e) + + +def _parse_triplet(text: str) -> Optional[Tuple[int, int, int]]: + parts = text.split("x") + if len(parts) != 3: + return None + try: + return (int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + return None + + +def _parse_gemm_header_metadata(header: Path) -> Optional[Dict[str, Any]]: + """ + Parse GEMM header name into configuration metadata. 
+ + Expected stem format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler} + _{pad_m}_{pad_n}_{pad_k}_{persistent} + _{tile_m}x{tile_n}x{tile_k}_{wave_m}x{wave_n}x{wave_k}_{warp_m}x{warp_n}x{warp_k} + """ + parts = header.stem.split("_") + if len(parts) < 13 or parts[0] != "gemm": + return None + + tile = _parse_triplet(parts[10]) + wave = _parse_triplet(parts[11]) + warp = _parse_triplet(parts[12]) + if tile is None or wave is None or warp is None: + return None + + def _as_bool(v: str) -> bool: + return v.lower() == "true" + + return { + "dtype": parts[1], + "layout": parts[2], + "pipeline": parts[3], + "epilogue": parts[4], + "scheduler": parts[5], + "pad_m": _as_bool(parts[6]), + "pad_n": _as_bool(parts[7]), + "pad_k": _as_bool(parts[8]), + "persistent": _as_bool(parts[9]), + "tile": tile, + "wave": wave, + "warp": warp, + } + + +def _generate_arch_valid_gemm_headers( + python_exe: str, + codegen_script: Path, + output_dir: Path, + dtype: str, + layout: str, + gpu_target: str, + variant: str = "standard", +) -> Tuple[bool, List[Path], str]: + """Generate (or reuse) an arch-filtered kernel catalog for fallback selection.""" + output_dir.mkdir(parents=True, exist_ok=True) + pattern = f"gemm_{dtype}_{layout}_*.hpp" + existing = sorted(output_dir.glob(pattern)) + if existing: + return True, existing, "" + + cmd = [ + python_exe, + str(codegen_script), + "--output-dir", + str(output_dir), + "--datatype", + dtype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, [], f"Catalog codegen failed: {err}" + + generated = sorted(output_dir.glob(pattern)) + if not generated: + return False, [], "Catalog codegen produced no GEMM headers" + return True, generated, "" + + +def _select_best_arch_valid_gemm_header( + config: "KernelConfig", + headers: List[Path], +) 
-> Tuple[Optional[Path], Optional[Dict[str, Any]]]: + """Choose nearest arch-valid header for a requested GEMM config.""" + best: Optional[Path] = None + best_meta: Optional[Dict[str, Any]] = None + best_score: Optional[Tuple[int, int, int, int, int, int]] = None + + for h in headers: + meta = _parse_gemm_header_metadata(h) + if meta is None: + continue + if meta["dtype"] != config.dtype_a or meta["layout"] != config.layout: + continue + + tile = meta["tile"] + wave = meta["wave"] + warp = meta["warp"] + tile_delta = ( + abs(tile[0] - config.tile_m) + + abs(tile[1] - config.tile_n) + + abs(tile[2] - config.tile_k) + ) + wave_delta = ( + abs(wave[0] - config.wave_m) + + abs(wave[1] - config.wave_n) + + abs(wave[2] - config.wave_k) + ) + warp_delta = ( + abs(warp[0] - config.warp_m) + + abs(warp[1] - config.warp_n) + + abs(warp[2] - config.warp_k) + ) + score = ( + 0 if meta["pipeline"] == config.pipeline else 1, + 0 if meta["scheduler"] == config.scheduler else 1, + 0 if meta["epilogue"] == config.epilogue else 1, + tile_delta, + wave_delta, + warp_delta, + ) + if best_score is None or score < best_score: + best_score = score + best = h + best_meta = meta + + return best, best_meta + + # ============================================================================= # Preshuffle Utilities # ============================================================================= @@ -1319,7 +1576,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1337,7 +1594,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1399,7 +1656,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if 
result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1417,7 +1674,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {tile_str}: FAILED - {e}") + print(f" FAIL {tile_str}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1481,7 +1738,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1499,7 +1756,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1767,7 +2024,7 @@ class CodegenRunner: link_cmd, capture_output=True, text=True, timeout=300 ) if result.returncode == 0: - print(f" ✓ Library rebuilt: {lib_path.name}") + print(f" OK Library rebuilt: {lib_path.name}") # Clean up object file obj_file.unlink(missing_ok=True) return lib_path @@ -1781,6 +2038,105 @@ class CodegenRunner: print(f" Build error: {e}") return None + def build_libraries_parallel( + self, configs_and_headers: List[Tuple[KernelConfig, Path]], verbose: bool = True + ) -> List[Optional[Path]]: + """ + Build multiple libraries in parallel using ProcessPoolExecutor. + Returns a list of library paths (or None if a build failed) in the same order. 
+ """ + import time + from concurrent.futures import ProcessPoolExecutor, as_completed + + start_time = time.time() + build_dir = get_build_dir() + root = get_dispatcher_root() + ck_root = root.parent + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print(" Required source or static library missing for parallel build.") + return [None] * len(configs_and_headers) + + args_list = [] + for config, kernel_header in configs_and_headers: + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_{config.tile_str}_{config.pipeline}.so" + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + args_list.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": f"{config.dtype_a}_{config.layout}_{config.tile_str}", + } + ) + + if verbose: + print( + f"Building {len(args_list)} libraries in parallel (workers={self.max_workers})..." 
+ ) + + results_map = {} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, args): i + for i, args in enumerate(args_list) + } + for future in as_completed(futures): + idx = futures[future] + success, lib_path, err = future.result() + results_map[idx] = Path(lib_path) if success else None + if verbose: + status = "OK" if success else f"FAIL ({err})" + print( + f" {status} {Path(lib_path).name if success else args_list[idx]['config_name']}" + ) + + if verbose: + elapsed = time.time() - start_time + print(f"Parallel build finished in {elapsed:.2f}s") + + return [results_map[i] for i in range(len(configs_and_headers))] + def generate_preselected( self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None ) -> CodegenResult: @@ -1933,6 +2289,28 @@ class Registry: """Bind to a loaded dispatcher library.""" self._lib = lib + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> List["GemmSetupResult"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a GemmSetupResult per registered kernel (same order as get_kernels()). 
+ """ + if not self._kernels: + return [] + return setup_multiple_gemm_dispatchers( + self._kernels, + registry_name=self._name, + verbose=verbose, + max_workers=max_workers, + ) + def __repr__(self) -> str: return f"Registry(name='{self._name}', kernels={self.kernel_count})" @@ -2109,7 +2487,7 @@ def setup_gemm_dispatcher( log(" Validating config...") validation = validate_kernel_config(config) if not validation.is_valid: - log(" ⚠ Auto-correcting configuration...") + log(" WARNING Auto-correcting configuration...") config, was_modified, corrections = auto_correct_kernel_config( config, verbose=verbose ) @@ -2128,13 +2506,13 @@ def setup_gemm_dispatcher( codegen_result = codegen.generate_from_config(config) if not codegen_result.success: - log(" ⚠ Kernel generation: using existing") + log(" WARNING Kernel generation: using existing") # Step 3: Find matching kernel header kernel_header = find_matching_kernel_header(config) result.kernel_header = kernel_header if not kernel_header: - log(" ⚠ No matching kernel header found") + log(" WARNING No matching kernel header found") # Step 4: Load library log(" Loading library...") @@ -2188,11 +2566,11 @@ def setup_gemm_dispatcher( result.error = "Failed to load rebuilt library" return result result.lib = lib - log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + log(f" OK Rebuilt library: {lib.get_kernel_name()}") else: - log(" ⚠ Rebuild failed, using existing library") + log(" WARNING Rebuild failed, using existing library") else: - log(" ⚠ No kernel header found for config, using existing library") + log(" WARNING No kernel header found for config, using existing library") # Step 5: Create registry and dispatcher log(" Creating registry and dispatcher...") @@ -2203,12 +2581,305 @@ def setup_gemm_dispatcher( dispatcher = Dispatcher(registry=registry, lib=lib) result.dispatcher = dispatcher - log(f" ✓ Ready: {lib.get_kernel_name()}") + log(f" OK Ready: {lib.get_kernel_name()}") result.success = True return result +def 
setup_multiple_gemm_dispatchers( + configs: List[KernelConfig], + registry_name: str = "gemm_registry", + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[GemmSetupResult]: + """ + Setup multiple GEMM dispatchers in parallel. + + Pipeline: + 1. Validate + auto-correct each config + 2. Parallel codegen: generate .hpp for each config via --config JSON + 3. Parallel hipcc: compile each .hpp -> .so + 4. Load + wire up each .so into a GemmSetupResult + + Each config gets its own .so, so different tile sizes can coexist. + + Args: + max_workers: Max parallel processes for codegen/compile (default: cpu_count capped at 8). + """ + import sys + + results = [GemmSetupResult(success=False, config=c) for c in configs] + max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + # -- Step 1: Validate & correct --------------------------------------- + valid_configs = [] + for i, c in enumerate(configs): + val = validate_kernel_config(c) + if not val.is_valid: + c, modified, corrections = auto_correct_kernel_config(c, verbose=False) + results[i].config = c + results[i].corrections = corrections + valid_configs.append(c) + + # -- Step 2: Parallel codegen (one --config JSON per config) ---------- + codegen_script = get_codegen_path() + output_dir = get_generated_kernels_dir() + + codegen_args = [] + for c in valid_configs: + tile_str = c.tile_str + wave_str = f"{c.wave_m}x{c.wave_n}x{c.wave_k}" + warp_str = f"{c.warp_m}x{c.warp_n}x{c.warp_k}" + + tile_config_json = { + "tile_config": { + "tile_m": [c.tile_m], + "tile_n": [c.tile_n], + "tile_k": [c.tile_k], + "warp_m": [c.wave_m], + "warp_n": [c.wave_n], + "warp_k": [c.wave_k], + "warp_tile_m": [c.warp_m], + "warp_tile_n": [c.warp_n], + "warp_tile_k": [c.warp_k], + }, + "trait_config": { + "pipeline": [c.pipeline], + "epilogue": [c.epilogue], + "scheduler": [c.scheduler], + "pad_m": [c.pad_m], + "pad_n": [c.pad_n], + "pad_k": [c.pad_k], + "persistent": [False], + }, + } + + hpp_pattern = ( + 
f"gemm_{c.dtype_a}_{c.layout}_{c.pipeline}_{c.epilogue}_{c.scheduler}" + f"_*_{tile_str}_{wave_str}_{warp_str}.hpp" + ) + + codegen_args.append( + { + "python": sys.executable, + "codegen_script": str(codegen_script), + "output_dir": str(output_dir), + "dtype": c.dtype_a, + "layout": c.layout, + "gpu_target": c.gfx_arch, + "tile_config_json": tile_config_json, + "hpp_glob_pattern": hpp_pattern, + } + ) + + if verbose: + print( + f"Generating {len(codegen_args)} kernel headers in parallel (workers={max_workers})..." + ) + + headers: List[Optional[Path]] = [None] * len(valid_configs) + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_kernel_subprocess, a): i + for i, a in enumerate(codegen_args) + } + for future in as_completed(futures): + idx = futures[future] + ok, hdr_str, err = future.result() + if ok and hdr_str: + headers[idx] = Path(hdr_str) + results[idx].kernel_header = Path(hdr_str) + if verbose: + print( + f" OK [{idx}] {valid_configs[idx].tile_str}: {Path(hdr_str).name}" + ) + else: + results[idx].error = f"Codegen: {err}" + if verbose: + print(f" FAIL [{idx}] {valid_configs[idx].tile_str}: {err}") + + # For configs rejected by arch filter, map to nearest arch-valid header. + fallback_needed = [i for i, h in enumerate(headers) if h is None] + if fallback_needed: + if verbose: + print( + f"Resolving {len(fallback_needed)} configs via arch-valid GEMM catalog..." 
+ ) + + catalog_cache: Dict[Tuple[str, str, str, str], List[Path]] = {} + for i in fallback_needed: + c = valid_configs[i] + key = (c.gfx_arch, c.dtype_a, c.layout, c.variant) + if key not in catalog_cache: + catalog_dir = ( + output_dir + / "_arch_valid_catalog" + / (f"{c.gfx_arch}_{c.dtype_a}_{c.layout}_{c.variant}") + ) + ok, catalog_headers, err = _generate_arch_valid_gemm_headers( + python_exe=sys.executable, + codegen_script=codegen_script, + output_dir=catalog_dir, + dtype=c.dtype_a, + layout=c.layout, + gpu_target=c.gfx_arch, + variant=c.variant, + ) + if not ok: + catalog_headers = [] + if verbose: + print(f" FAIL [{i}] catalog generation: {err}") + catalog_cache[key] = catalog_headers + + chosen, meta = _select_best_arch_valid_gemm_header(c, catalog_cache[key]) + if chosen is None or meta is None: + continue + + headers[i] = chosen + results[i].kernel_header = chosen + results[i].error = "" + + # Keep Python-side config aligned with the selected kernel header. + valid_configs[i].pipeline = str(meta["pipeline"]) + valid_configs[i].epilogue = str(meta["epilogue"]) + valid_configs[i].scheduler = str(meta["scheduler"]) + valid_configs[i].pad_m = bool(meta["pad_m"]) + valid_configs[i].pad_n = bool(meta["pad_n"]) + valid_configs[i].pad_k = bool(meta["pad_k"]) + valid_configs[i].tile_m = int(meta["tile"][0]) + valid_configs[i].tile_n = int(meta["tile"][1]) + valid_configs[i].tile_k = int(meta["tile"][2]) + valid_configs[i].wave_m = int(meta["wave"][0]) + valid_configs[i].wave_n = int(meta["wave"][1]) + valid_configs[i].wave_k = int(meta["wave"][2]) + valid_configs[i].warp_m = int(meta["warp"][0]) + valid_configs[i].warp_n = int(meta["warp"][1]) + valid_configs[i].warp_k = int(meta["warp"][2]) + results[i].config = valid_configs[i] + + if verbose: + print(f" INFO [{i}] mapped to arch-valid header: {chosen.name}") + + # -- Step 3: Parallel hipcc compilation ------------------------------- + root = get_dispatcher_root() + ck_root = root.parent + build_dir = 
get_build_dir() + ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + for i in range(len(valid_configs)): + if results[i].error == "": + results[ + i + ].error = "Missing ctypes source or static library for compilation" + return results + + compile_jobs = [] + compile_index_map = {} + for i, c in enumerate(valid_configs): + hdr = headers[i] + if hdr is None: + continue + + lib_name = ( + f"libdispatcher_gemm_{c.dtype_a}_{c.layout}_{c.tile_str}_{c.pipeline}.so" + ) + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{str(output_dir)}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{hdr}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.gfx_arch}", + f'-DGFX_ARCH="{c.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_index_map[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + } + ) + + if verbose and compile_jobs: + print( + f"Compiling {len(compile_jobs)} libraries in parallel (workers={max_workers})..." 
+ ) + + lib_paths: Dict[int, Optional[Path]] = {} + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + j = futures[future] + i = compile_index_map[j] + ok, lp, err = future.result() + if ok and lp: + lib_paths[i] = Path(lp) + if verbose: + print(f" OK [{i}] {valid_configs[i].tile_str}: {Path(lp).name}") + else: + results[i].error = f"Compile: {err}" + if verbose: + print(f" FAIL [{i}] {valid_configs[i].tile_str}: {err}") + + # -- Step 4: Load libraries and create dispatchers -------------------- + for i, c in enumerate(valid_configs): + lp = lib_paths.get(i) + if lp is None: + continue + + lib = DispatcherLib.load(lp) + if lib is not None and lib.initialize(): + results[i].lib = lib + reg = Registry(name=f"{registry_name}_{i}", lib=lib) + reg.register_kernel(c) + results[i].registry = reg + results[i].dispatcher = Dispatcher(registry=reg, lib=lib) + results[i].success = True + else: + results[i].error = "Failed to load compiled library" + + if verbose: + ok_count = sum(1 for r in results if r.success) + print(f"Setup complete: {ok_count}/{len(results)} dispatchers ready") + + return results + + def cleanup_gemm(): """ Cleanup function to call after running GEMM examples. diff --git a/dispatcher/python/dispatcher_common.py b/dispatcher/python/dispatcher_common.py new file mode 100644 index 0000000000..a19ecbdb49 --- /dev/null +++ b/dispatcher/python/dispatcher_common.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Shared Python dispatcher utilities for GEMM and grouped convolution. + +Extracted from ctypes_utils.py (GEMM) + compile_grouped_conv_examples.py (grouped conv). +Both ctypes_utils.py and grouped_conv_utils.py import from here to +eliminate duplication. 
+ +Best-of-both: + - Validation and auto-correction return typed objects (GEMM pattern) + - Colors class with cross-platform ANSI handling (conv pattern) + - Phased output helpers (conv pattern) + - logging module instead of bare print() (shared improvement) +""" + +import logging +import shutil +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Path Configuration +# ============================================================================ + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory (parent of python/).""" + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory (parent of dispatcher/).""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory.""" + return get_dispatcher_root() / "build" + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory.""" + return get_build_dir() / "generated_kernels" + + +def get_codegen_dir() -> Path: + """Get the codegen scripts directory.""" + return get_dispatcher_root() / "codegen" + + +# ============================================================================ +# Architecture Filter Data +# ============================================================================ + +_arch_data_cache: Optional[Dict[str, Any]] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """Detect the GPU architecture from rocminfo. 
Falls back to the given default.""" + import subprocess + + try: + out = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.DEVNULL + ) + for line in out.splitlines(): + if "Name:" in line and "gfx" in line: + return line.split()[-1].strip() + except Exception: + pass + return fallback + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available. + + Returns dict with keys: trait_unsupported, warp_combos, + warp_tile_combos, supported_archs. + """ + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + return _arch_data_cache + + +# ============================================================================ +# Validation Result +# ============================================================================ + + +@dataclass +class ValidationResultBase: + """Result of kernel config validation (shared base for GEMM and conv).""" 
def validate_wave_config(wave_cfg: List[int], arch: str) -> Tuple[bool, str]:
    """Validate a [wave_m, wave_n, wave_k] config for *arch*.

    Args:
        wave_cfg: The wave layout to check.  Any sequence is accepted; it is
            compared element-wise, so ``(2, 2, 1)`` matches ``[2, 2, 1]``.
        arch: Key into the ``warp_combos`` table returned by
            ``get_arch_filter_data()``.  An architecture missing from that
            table is checked against the single layout ``[2, 2, 1]``.

    Returns:
        ``(True, "")`` when the layout is listed for *arch*, otherwise
        ``(False, message)`` where the message names the valid layouts.
    """
    data = get_arch_filter_data()
    valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]])

    # A tuple never compares equal to a list in Python, so normalise both
    # sides before comparing; otherwise a valid layout passed as a tuple (or
    # a table that stores tuples) would be reported as unsupported.
    try:
        requested = list(wave_cfg)
    except TypeError:
        # Not a sequence at all: keep the value so it simply fails to match.
        requested = wave_cfg

    if any(requested == list(candidate) for candidate in valid_waves):
        return True, ""

    valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_waves)
    return (
        False,
        f"Unsupported wave configuration {wave_cfg} for {arch}. "
        f"Valid wave configs: {valid_str}",
    )
def validate_trait_combo(
    pipeline: str, epilogue: str, scheduler: str
) -> Tuple[bool, str]:
    """Validate a (pipeline, epilogue, scheduler) combination.

    The triple is looked up in the ``trait_unsupported`` collection returned
    by ``get_arch_filter_data()``.

    Returns:
        ``(True, "")`` when the triple is not listed as unsupported,
        otherwise ``(False, message)`` describing the rejected combination.
    """
    unsupported = get_arch_filter_data()["trait_unsupported"]
    if (pipeline, epilogue, scheduler) not in unsupported:
        return True, ""
    return (
        False,
        f"Unsupported trait combination: pipeline={pipeline}, "
        f"epilogue={epilogue}, scheduler={scheduler}",
    )


# ============================================================================
# Auto-Correction Helpers
# ============================================================================


def auto_correct_wave(wave_cfg: List[int], arch: str) -> List[int]:
    """Return a wave config that is valid for *arch*.

    Args:
        wave_cfg: Requested ``[wave_m, wave_n, wave_k]`` layout.  Any sequence
            is accepted and compared element-wise.
        arch: Key into the ``warp_combos`` table; an unknown architecture is
            treated as supporting only ``[2, 2, 1]``.

    Returns:
        *wave_cfg* itself when it is already valid; otherwise a fresh copy of
        the first layout listed for *arch* (``[2, 2, 1]`` if none is listed).
    """
    data = get_arch_filter_data()
    valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]])

    # Normalise to lists so a tuple such as (2, 2, 1) is recognised as valid.
    try:
        requested = list(wave_cfg)
    except TypeError:
        requested = wave_cfg
    if any(requested == list(candidate) for candidate in valid_waves):
        return wave_cfg

    if not valid_waves:
        return [2, 2, 1]
    # Hand out a copy: the table is cached module-wide, and returning the
    # stored list would let a caller's in-place edit corrupt it for everyone.
    return list(valid_waves[0])
class Colors:
    """ANSI colour wrappers that degrade to plain text.

    Escape codes are emitted only when ``sys.platform`` is not ``"win32"``
    and ``sys.stdout`` reports itself as a TTY; otherwise every helper
    returns its argument unchanged, so piped/redirected output stays clean.
    """

    _GREEN = "\033[0;32m"
    _YELLOW = "\033[1;33m"
    _RED = "\033[0;31m"
    _CYAN = "\033[0;36m"
    _BOLD = "\033[1m"
    _NC = "\033[0m"

    @classmethod
    def _use_color(cls) -> bool:
        """Tell whether escape sequences should be written to stdout."""
        if sys.platform == "win32":
            return False
        if not hasattr(sys.stdout, "isatty"):
            return False
        return sys.stdout.isatty()

    @classmethod
    def _wrap(cls, code: str, text: str) -> str:
        """Surround *text* with *code* and the reset sequence when colour is on."""
        if not cls._use_color():
            return text
        return f"{code}{text}{cls._NC}"

    @classmethod
    def green(cls, text: str) -> str:
        return cls._wrap(cls._GREEN, text)

    @classmethod
    def red(cls, text: str) -> str:
        return cls._wrap(cls._RED, text)

    @classmethod
    def yellow(cls, text: str) -> str:
        return cls._wrap(cls._YELLOW, text)

    @classmethod
    def cyan(cls, text: str) -> str:
        return cls._wrap(cls._CYAN, text)

    @classmethod
    def bold(cls, text: str) -> str:
        return cls._wrap(cls._BOLD, text)
def print_success(message: str) -> None:
    """Print *message* (green when colour is enabled) after an ``OK`` tag."""
    coloured = Colors.green(message)
    print(f" OK {coloured}")


def print_error(message: str) -> None:
    """Print *message* (red when colour is enabled) after a ``FAIL`` tag."""
    coloured = Colors.red(message)
    print(f" FAIL {coloured}")


def print_info(message: str) -> None:
    """Print *message* as an informational line (cyan when colour is enabled)."""
    coloured = Colors.cyan(message)
    print(f" {coloured}")


# ============================================================================
# Cleanup Helpers
# ============================================================================


def cleanup_generated_kernels(gen_dir: Optional[Path] = None) -> None:
    """Delete the generated-kernel directory, if there is one.

    Args:
        gen_dir: Directory to remove.  When omitted, the path returned by
            ``get_generated_kernels_dir()`` is used.

    Removal is best-effort (``ignore_errors=True``); nothing happens when
    the directory does not exist.
    """
    target = get_generated_kernels_dir() if gen_dir is None else gen_dir
    if not target.exists():
        return
    shutil.rmtree(target, ignore_errors=True)
    log.info("Cleaned up generated kernels at %s", target)


# ============================================================================
# Tool Helpers
# ============================================================================


def find_hipcc() -> Optional[str]:
    """Locate the ``hipcc`` compiler.

    Candidates are tried in this order: the ``HIPCC`` environment variable,
    ``/opt/rocm/bin/hipcc``, then whatever ``shutil.which("hipcc")`` finds.

    Returns:
        The first candidate that is an existing file, or ``None``.
    """
    import os

    search_order = (
        os.environ.get("HIPCC"),
        "/opt/rocm/bin/hipcc",
        shutil.which("hipcc"),
    )
    return next(
        (cand for cand in search_order if cand and os.path.isfile(cand)),
        None,
    )
+ +Classes: + GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch) + GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.) + GroupedConvProblemC - ctypes struct matching C++ ConvProblemC + GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so + GpuGroupedConvRunner - High-level GPU execution runner + GroupedConvResult - Result of GPU execution (output, time, tflops) + GroupedConvRegistry - Collection of kernel configs with JSON export + +Usage: + from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GpuGroupedConvRunner, + ) + + config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2) + problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3, + stride_h=1, pad_h=1, direction="forward") + runner = GpuGroupedConvRunner() + if runner.is_available(): + result = runner.run(input_np, weight_np, problem) + print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}") +""" + +import ctypes +import json +import copy +import subprocess +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from dispatcher_common import ( + ValidationResultBase, + auto_correct_trait, + auto_correct_wave, + get_arch_filter_data, + validate_trait_combo, + validate_wave_config, + validate_warp_tile_config, +) + + +# ============================================================================= +# Constants +# ============================================================================= + +VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight") +VALID_NDIM_SPATIAL = (1, 2, 3) +BACKWARD_VARIANTS = ("bwd_data", "bwd_weight") +BACKWARD_PIPELINES = ("compv3", "mem") + +VARIANT_ALIASES = { + "2d_fwd": "forward", + "2d_bwdd": "bwd_data", + "2d_bwdw": "bwd_weight", + "fwd": "forward", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", +} + +DIRECTION_MAP = {"forward": 0, "bwd_data": 1, 
"bwd_weight": 2} + + +def _resolve_variant(v: str) -> str: + return VARIANT_ALIASES.get(v, v) + + +# ============================================================================= +# GroupedConvDataType +# ============================================================================= + + +class GroupedConvDataType(Enum): + FP16 = "fp16" + BF16 = "bf16" + FP32 = "fp32" + FP8 = "fp8" + BF8 = "bf8" + INT8 = "int8" + + +# ============================================================================= +# GroupedConvKernelConfig +# ============================================================================= + + +@dataclass +class GroupedConvKernelConfig: + """Complete kernel configuration for grouped convolution. + + Captures all parameters needed to identify and run a specific kernel. + Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm. + """ + + # What: signature + variant: str = "forward" + ndim_spatial: int = 2 + dtype: str = "fp16" + layout: str = "nhwgc" + arch: str = "gfx942" + + # How: algorithm - tile shape + tile_m: int = 1 + tile_n: int = 128 + tile_k: int = 128 + + # How: wave config + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # How: warp tile + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + # How: pipeline traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # ConvConfigBase parity fields + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + def __post_init__(self): + self.variant = _resolve_variant(self.variant) + if ( + self.variant in BACKWARD_VARIANTS + and self.pipeline not in BACKWARD_PIPELINES + ): + self.pipeline = "compv3" + + @property + def tile_str(self) -> str: + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + @property + def 
wave_str(self) -> str: + return f"{self.wave_m}x{self.wave_n}x{self.wave_k}" + + @property + def warp_str(self) -> str: + return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}" + + @property + def vec_str(self) -> str: + return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}" + + @property + def name(self) -> str: + return ( + f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d_" + f"{self.tile_str}_{self.pipeline}" + ) + + def to_dict(self) -> dict: + """Convert to legacy dict format for codegen compatibility.""" + return { + "tile_config": { + "tile_m": [self.tile_m], + "tile_n": [self.tile_n], + "tile_k": [self.tile_k], + "wave_m": [self.wave_m], + "wave_n": [self.wave_n], + "wave_k": [self.wave_k], + "warp_tile_m": [self.warp_tile_m], + "warp_tile_n": [self.warp_tile_n], + "warp_tile_k": [self.warp_tile_k], + }, + "trait_config": { + "pipeline": [self.pipeline], + "epilogue": [self.epilogue], + "scheduler": [self.scheduler], + "pad_m": [self.pad_m], + "pad_n": [self.pad_n], + "pad_k": [self.pad_k], + "vector_size_a": [self.vector_size_a], + "vector_size_b": [self.vector_size_b], + "vector_size_c": [self.vector_size_c], + "block_per_cu": [self.block_per_cu], + "num_wave_groups": [self.num_wave_groups], + "num_groups_to_merge": [self.num_groups_to_merge], + }, + "variant": self.variant, + "ndim_spatial": self.ndim_spatial, + "arch": self.arch, + "layout": self.layout, + "dtype": self.dtype, + } + + def to_json_obj(self) -> dict: + """Serializable dict for JSON export.""" + return { + "name": self.name, + "signature": { + "variant": self.variant, + "dtype": self.dtype, + "ndim_spatial": self.ndim_spatial, + "layout": self.layout, + }, + "algorithm": { + "tile_m": self.tile_m, + "tile_n": self.tile_n, + "tile_k": self.tile_k, + "wave": self.wave_str, + "warp": self.warp_str, + "pipeline": self.pipeline, + "epilogue": self.epilogue, + "scheduler": self.scheduler, + "vector_sizes": [ + self.vector_size_a, + self.vector_size_b, 
@dataclass
class GroupedConvProblem:
    """Sizes and parameters of one grouped convolution to be computed.

    Holds batch/channel counts, input and filter extents, stride, padding
    and dilation per spatial axis, plus the direction string and split-K
    factor.  Output extents, FLOP counts and tensor shapes are derived.
    Matches the old ConvProblem from conv_utils.py.
    """

    N: int = 1
    C: int = 64
    K: int = 128
    G: int = 1

    Hi: int = 28
    Wi: int = 28
    Di: int = 1

    Y: int = 3
    X: int = 3
    Z: int = 1

    stride_h: int = 1
    stride_w: int = 1
    stride_d: int = 1

    pad_h: int = 0
    pad_w: int = 0
    pad_d: int = 0

    dilation_h: int = 1
    dilation_w: int = 1
    dilation_d: int = 1

    direction: str = "forward"
    split_k: int = 1

    @staticmethod
    def _out_extent(size: int, pad: int, filt: int, dilation: int, stride: int) -> int:
        """Output length along one axis: (size + 2*pad - dilated_filter) // stride + 1."""
        dilated = (filt - 1) * dilation + 1
        return (size + 2 * pad - dilated) // stride + 1

    @property
    def Ho(self) -> int:
        return self._out_extent(
            self.Hi, self.pad_h, self.Y, self.dilation_h, self.stride_h
        )

    @property
    def Wo(self) -> int:
        return self._out_extent(
            self.Wi, self.pad_w, self.X, self.dilation_w, self.stride_w
        )

    @property
    def Do(self) -> int:
        return self._out_extent(
            self.Di, self.pad_d, self.Z, self.dilation_d, self.stride_d
        )

    @property
    def is_3d(self) -> bool:
        # Any non-trivial depth parameter makes the problem three-dimensional.
        return any((self.Di > 1, self.Z > 1, self.pad_d > 0))

    @property
    def ndim_spatial(self) -> int:
        if self.is_3d:
            return 3
        return 2

    @property
    def flops(self) -> float:
        """Total FLOPs for this convolution (any direction, same count)."""
        c_per_group = self.C // self.G
        if self.is_3d:
            factors = (
                self.N,
                self.K,
                self.Do,
                self.Ho,
                self.Wo,
                c_per_group,
                self.Z,
                self.Y,
                self.X,
            )
        else:
            factors = (self.N, self.K, self.Ho, self.Wo, c_per_group, self.Y, self.X)

        # Multiply left to right starting from 2.0 so the floating-point
        # result is the same as a single chained product.
        total = 2.0
        for factor in factors:
            total = total * factor
        return total

    @property
    def gflops(self) -> float:
        return self.flops / 1e9

    def input_shape(self) -> tuple:
        """NHWGC or NDHWGC layout."""
        spatial = (self.Di, self.Hi, self.Wi) if self.is_3d else (self.Hi, self.Wi)
        return (self.N, *spatial, self.G, self.C // self.G)

    def weight_shape(self) -> tuple:
        """GKYXC or GKZYXC layout."""
        c_per_g = self.C // self.G
        k_per_g = self.K // self.G
        taps = (self.Z, self.Y, self.X) if self.is_3d else (self.Y, self.X)
        return (self.G, k_per_g, *taps, c_per_g)

    def output_shape(self) -> tuple:
        """NHWGK or NDHWGK layout."""
        k_per_g = self.K // self.G
        spatial = (self.Do, self.Ho, self.Wo) if self.is_3d else (self.Ho, self.Wo)
        return (self.N, *spatial, self.G, k_per_g)

    def print_problem(self, indent: str = " "):
        """Print a human-readable summary of the problem to stdout."""
        three_d = self.is_3d
        dim_str = "3D" if three_d else "2D"
        print(f"{indent}GroupedConvProblem ({dim_str} {self.direction}):")
        print(f"{indent} Batch: N={self.N}, G={self.G}")
        print(f"{indent} Channels: C={self.C}, K={self.K}")
        if not three_d:
            print(f"{indent} Input: Hi={self.Hi}, Wi={self.Wi}")
            print(f"{indent} Filter: Y={self.Y}, X={self.X}")
            print(f"{indent} Output: Ho={self.Ho}, Wo={self.Wo}")
        else:
            print(f"{indent} Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}")
            print(f"{indent} Filter: Z={self.Z}, Y={self.Y}, X={self.X}")
            print(f"{indent} Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}")
        print(f"{indent} GFLOPs: {self.gflops:.2f}")
p.Hi, p.Wi + c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X + c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w + c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w + c.dilation_d, c.dilation_h, c.dilation_w = ( + p.dilation_d, + p.dilation_h, + p.dilation_w, + ) + c.direction = DIRECTION_MAP.get(p.direction, 0) + c.split_k = getattr(p, "split_k", 1) + return c + + +# ============================================================================= +# GroupedConvResult +# ============================================================================= + + +@dataclass +class GroupedConvResult: + """Result of GPU convolution execution.""" + + success: bool = False + time_ms: float = 0.0 + tflops: float = 0.0 + output: Optional[np.ndarray] = None + error: str = "" + + +# ============================================================================= +# GroupedConvDispatcherLib +# ============================================================================= + + +class GroupedConvDispatcherLib: + """Wrapper for the compiled convolution dispatcher library. + + Provides Python interface to the C API in conv_ctypes_lib.cpp. 
+ """ + + SEARCH_PATHS = [ + "build/examples/libdispatcher_conv_lib.so", + "build/bindings/libdispatcher_conv_lib.so", + "build/lib/libdispatcher_conv_lib.so", + ] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._setup_functions() + + def _setup_functions(self): + self._lib.conv_dispatcher_init.argtypes = [] + self._lib.conv_dispatcher_init.restype = ctypes.c_int + self._lib.conv_dispatcher_cleanup.argtypes = [] + self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int + self._lib.conv_dispatcher_version.argtypes = [] + self._lib.conv_dispatcher_version.restype = ctypes.c_char_p + self._lib.conv_dispatcher_has_kernels.argtypes = [] + self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_data.argtypes = [] + self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_weight.argtypes = [] + self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_count.argtypes = [] + self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_name.argtypes = [ + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int + self._lib.conv_dispatcher_is_supported.argtypes = [ + ctypes.POINTER(GroupedConvProblemC), + ] + self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int + self._lib.conv_dispatcher_run.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.POINTER(GroupedConvProblemC), + ctypes.c_void_p, + ] + self._lib.conv_dispatcher_run.restype = ctypes.c_float + + @classmethod + def find(cls) -> Optional["GroupedConvDispatcherLib"]: + """Search standard paths for the conv library.""" + root = Path(__file__).parent.parent + for rel in cls.SEARCH_PATHS: + path = root / rel + if path.exists(): + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError: 
class GpuGroupedConvRunner:
    """High-level GPU convolution runner.

    Loads the grouped-conv dispatcher library and the HIP runtime
    (``libamdhip64.so``) through ctypes, stages NumPy arrays into device
    memory, launches the kernel and copies the result back to the host.
    Follows the same pattern as the old GpuConvRunner from conv_utils.py.

    Usage:
        runner = GpuGroupedConvRunner()
        if runner.is_available():
            result = runner.run(input_np, weight_np, problem)
            print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}")
    """

    # Copy-direction values passed as the last argument of hipMemcpy.
    HIP_MEMCPY_H2D = 1
    HIP_MEMCPY_D2H = 2

    def __init__(self, lib_path: Optional[str] = None):
        """Load the dispatcher library and the HIP runtime.

        Args:
            lib_path: Explicit path to the dispatcher shared library.  When
                omitted, ``GroupedConvDispatcherLib.find()`` is used.

        Construction never raises: any failure leaves the runner in the
        "not available" state reported by :meth:`is_available`.
        """
        self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None
        self._hip = None
        self._initialized = False

        try:
            if lib_path:
                lib = ctypes.CDLL(lib_path)
                self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(lib_path))
            else:
                self._dispatch_lib = GroupedConvDispatcherLib.find()

            if self._dispatch_lib is None:
                return

            self._hip = ctypes.CDLL("libamdhip64.so")
            self._hip.hipMalloc.argtypes = [
                ctypes.POINTER(ctypes.c_void_p),
                ctypes.c_size_t,
            ]
            self._hip.hipMalloc.restype = ctypes.c_int
            self._hip.hipFree.argtypes = [ctypes.c_void_p]
            self._hip.hipFree.restype = ctypes.c_int
            self._hip.hipMemcpy.argtypes = [
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_size_t,
                ctypes.c_int,
            ]
            self._hip.hipMemcpy.restype = ctypes.c_int
            self._hip.hipDeviceSynchronize.argtypes = []
            self._hip.hipDeviceSynchronize.restype = ctypes.c_int

            self._dispatch_lib.initialize()
            self._initialized = True
        except Exception:
            # Deliberate best-effort: a missing library or runtime simply
            # makes the runner unavailable instead of failing construction.
            self._initialized = False

    def is_available(self) -> bool:
        """True once both libraries loaded and the dispatcher initialised."""
        return self._initialized and self._dispatch_lib is not None

    @property
    def library_path(self) -> Optional[str]:
        if self._dispatch_lib:
            return str(self._dispatch_lib.path)
        return None

    @property
    def lib(self) -> Optional[GroupedConvDispatcherLib]:
        return self._dispatch_lib

    @staticmethod
    def _hip_check(status: int, what: str) -> None:
        """Raise RuntimeError when a HIP call returned a non-zero status."""
        if status != 0:
            raise RuntimeError(f"{what} failed (HIP error {status})")

    def _hip_alloc(self, nbytes: int, owned: list) -> ctypes.c_void_p:
        """Allocate *nbytes* on the device and record the buffer in *owned*."""
        ptr = ctypes.c_void_p()
        self._hip_check(self._hip.hipMalloc(ctypes.byref(ptr), nbytes), "hipMalloc")
        owned.append(ptr)
        return ptr

    def run(
        self,
        input_np: np.ndarray,
        weight_np: np.ndarray,
        problem: GroupedConvProblem,
        output_np: Optional[np.ndarray] = None,
    ) -> GroupedConvResult:
        """Run convolution on GPU.

        Args:
            input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X.
            weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY.
            problem: Problem specification.
            output_np: Optional pre-allocated output buffer.  When omitted, a
                zero array of the direction's result shape and the input's
                dtype is created.

        Returns:
            GroupedConvResult with success, time_ms, tflops, output.  Every
            failure (unavailable GPU, HIP error, unsupported problem, any
            exception) is reported through ``error`` rather than raised.
        """
        if not self.is_available():
            return GroupedConvResult(error="GPU not available")

        # Every successful allocation is recorded here so the ``finally``
        # clause can release it no matter where execution stops.
        device_buffers = []
        try:
            # Determine output shape based on direction
            direction = problem.direction
            if direction == "bwd_data":
                out_shape = problem.input_shape()
            elif direction == "bwd_weight":
                out_shape = problem.weight_shape()
            else:
                out_shape = problem.output_shape()

            if output_np is None:
                output_np = np.zeros(out_shape, dtype=input_np.dtype)

            # ``ndarray.ctypes.data`` is the address of the underlying
            # buffer; copying ``nbytes`` from it is only meaningful for a
            # C-contiguous array, so strided views are compacted first.
            host_in = np.ascontiguousarray(input_np)
            host_wei = np.ascontiguousarray(weight_np)
            if output_np.flags["C_CONTIGUOUS"]:
                host_out = output_np
            else:
                host_out = np.empty(output_np.shape, dtype=output_np.dtype)
            output_size = host_out.nbytes

            # Allocate GPU memory
            d_a = self._hip_alloc(host_in.nbytes, device_buffers)
            d_b = self._hip_alloc(host_wei.nbytes, device_buffers)
            d_c = self._hip_alloc(output_size, device_buffers)

            # Host to device
            self._hip_check(
                self._hip.hipMemcpy(
                    d_a, host_in.ctypes.data, host_in.nbytes, self.HIP_MEMCPY_H2D
                ),
                "hipMemcpy (input, host to device)",
            )
            self._hip_check(
                self._hip.hipMemcpy(
                    d_b, host_wei.ctypes.data, host_wei.nbytes, self.HIP_MEMCPY_H2D
                ),
                "hipMemcpy (weight, host to device)",
            )
            self._hip_check(self._hip.hipDeviceSynchronize(), "hipDeviceSynchronize")

            # Launch kernel
            time_ms = self._dispatch_lib.run(d_a.value, d_b.value, d_c.value, problem)
            sync_status = self._hip.hipDeviceSynchronize()

            result = GroupedConvResult()

            if time_ms > 0:
                self._hip_check(sync_status, "hipDeviceSynchronize")
                # Device to host
                self._hip_check(
                    self._hip.hipMemcpy(
                        host_out.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H
                    ),
                    "hipMemcpy (output, device to host)",
                )
                self._hip_check(
                    self._hip.hipDeviceSynchronize(), "hipDeviceSynchronize"
                )
                if host_out is not output_np:
                    # Scatter the staged result back into the caller's view.
                    np.copyto(output_np, host_out)
                result.success = True
                result.time_ms = time_ms
                result.tflops = problem.flops / (time_ms * 1e9)
                result.output = output_np
            else:
                # Non-positive return values are dispatcher error codes.
                result.error = (
                    "unsupported"
                    if time_ms == -3.0
                    else "no kernel"
                    if time_ms == -2.0
                    else f"error (code {time_ms})"
                )

            return result

        except Exception as e:
            return GroupedConvResult(error=str(e))
        finally:
            # Free GPU memory on every exit path, including failures.
            for buf in device_buffers:
                self._hip.hipFree(buf)

    def cleanup(self):
        """Ask the dispatcher library to clean up; errors are ignored."""
        if self._dispatch_lib:
            try:
                self._dispatch_lib.cleanup()
            except Exception:
                pass
+ """ + matching = [k for k in self._kernels if k.variant == problem.direction] + if not matching: + return None + + if heuristic is not None: + ranked = heuristic(problem) + for hint in ranked: + for k in matching: + if hint in k.name: + return k + + return matching[0] if matching else None + + def filter_by_variant(self, variant: str) -> "GroupedConvRegistry": + variant = _resolve_variant(variant) + reg = GroupedConvRegistry(f"{self.name}_{variant}") + for k in self._kernels: + if k.variant == variant: + reg.add(k) + return reg + + def filter_by_arch(self, arch: str) -> "GroupedConvRegistry": + reg = GroupedConvRegistry(f"{self.name}_{arch}") + for k in self._kernels: + if k.arch == arch: + reg.add(k) + return reg + + def to_json(self, indent: int = 2) -> str: + return json.dumps( + { + "name": self.name, + "kernels": [k.to_json_obj() for k in self._kernels], + }, + indent=indent, + ) + + @classmethod + def from_json(cls, json_str: str) -> "GroupedConvRegistry": + data = json.loads(json_str) + reg = cls(data.get("name", "imported")) + for kd in data.get("kernels", []): + sig = kd.get("signature", {}) + algo = kd.get("algorithm", {}) + wave = algo.get("wave", "2x2x1").split("x") + warp = algo.get("warp", "32x32x16").split("x") + vec = algo.get("vector_sizes", [4, 8, 8]) + reg.add( + GroupedConvKernelConfig( + variant=sig.get("variant", "forward"), + ndim_spatial=sig.get("ndim_spatial", 2), + dtype=sig.get("dtype", "fp16"), + layout=sig.get("layout", "nhwgc"), + arch=kd.get("arch", "gfx942"), + tile_m=algo.get("tile_m", 1), + tile_n=algo.get("tile_n", 128), + tile_k=algo.get("tile_k", 128), + wave_m=int(wave[0]), + wave_n=int(wave[1]), + wave_k=int(wave[2]), + warp_tile_m=int(warp[0]), + warp_tile_n=int(warp[1]), + warp_tile_k=int(warp[2]), + pipeline=algo.get("pipeline", "compv3"), + epilogue=algo.get("epilogue", "cshuffle"), + scheduler=algo.get("scheduler", "intrawave"), + vector_size_a=vec[0] if len(vec) > 0 else 4, + vector_size_b=vec[1] if len(vec) > 1 else 
8, + vector_size_c=vec[2] if len(vec) > 2 else 8, + block_per_cu=algo.get("block_per_cu", 1), + num_wave_groups=algo.get("num_wave_groups", 1), + num_groups_to_merge=algo.get("num_groups_to_merge", 1), + ) + ) + return reg + + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a dict mapping (variant, ndim_spatial) to a ready-to-use + GpuGroupedConvRunner. + """ + if not self._kernels: + return {} + + libs = setup_multiple_grouped_conv_dispatchers( + self._kernels, + verbose=verbose, + max_workers=max_workers, + ) + + runners: Dict[Tuple[str, int], GpuGroupedConvRunner] = {} + for cfg, lib in zip(self._kernels, libs): + if lib is None: + continue + key = (cfg.variant, cfg.ndim_spatial) + if key in runners: + continue + runner = GpuGroupedConvRunner(lib_path=str(lib.path)) + if runner.is_available(): + runners[key] = runner + return runners + + def print_registry(self, indent: str = " "): + print(f"{indent}Registry '{self.name}': {len(self)} kernels") + for i, k in enumerate(self._kernels): + print( + f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})" + ) + + +# ============================================================================= +# GroupedConvValidationResult +# ============================================================================= + + +@dataclass +class GroupedConvValidationResult(ValidationResultBase): + """Result of grouped conv kernel config validation.""" + + variant: str = "forward" + + def __init__( + self, + is_valid=True, + errors=None, + warnings=None, + suggested_fixes=None, + variant="forward", + ): + super().__init__( + is_valid=is_valid, + errors=errors or [], + warnings=warnings or [], + 
suggested_fixes=suggested_fixes or {}, + ) + self.variant = variant + + +# ============================================================================= +# Validation helpers (extracted from the original config extraction code) +# ============================================================================= + + +def _first(val): + if isinstance(val, list) and len(val) > 0: + return val[0] + return val + + +def _get_tile_config(config: dict) -> dict: + return config.get("tile_config") or {} + + +def _get_trait_config(config: dict) -> dict: + return config.get("trait_config") or {} + + +def _extract_wave_config(tile_config: dict) -> List[int]: + wm = tile_config.get("wave_m") or tile_config.get("warp_m") + wn = tile_config.get("wave_n") or tile_config.get("warp_n") + wk = tile_config.get("wave_k") or tile_config.get("warp_k") + if wm is not None and wn is not None and wk is not None: + return [_first(wm), _first(wn), _first(wk)] + return [2, 2, 1] + + +def _extract_warp_tile_config(tile_config: dict) -> List[int]: + wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m") + wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n") + wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k") + if wtm is not None and wtn is not None and wtk is not None: + return [_first(wtm), _first(wtn), _first(wtk)] + return [32, 32, 16] + + +def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]: + p = _first(trait_config.get("pipeline", "compv4")) + e = _first(trait_config.get("epilogue", "cshuffle")) + s = _first(trait_config.get("scheduler", "intrawave")) + if isinstance(p, list): + p = p[0] if p else "compv4" + if isinstance(e, list): + e = e[0] if e else "cshuffle" + if isinstance(s, list): + s = s[0] if s else "intrawave" + return (str(p), str(e), str(s)) + + +# ============================================================================= +# validate_grouped_conv_config / auto_correct_grouped_conv_config +# 
============================================================================= + + +def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult: + """Validate a grouped conv kernel config dict. + + Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output. + """ + errors: List[str] = [] + warnings: List[str] = [] + suggested_fixes: Dict[str, Any] = {} + + required = ( + "tile_config", + "trait_config", + "variant", + "ndim_spatial", + "arch", + "layout", + ) + for key in required: + if key not in config: + errors.append(f"Missing required key: {key}") + if errors: + return GroupedConvValidationResult( + is_valid=False, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=config.get("variant", "forward"), + ) + + tile_config = _get_tile_config(config) + trait_config = _get_trait_config(config) + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + + ndim_spatial = config.get("ndim_spatial") + arch = config.get("arch", "gfx942") + dtype = config.get("dtype", "fp16") + + if variant not in VALID_VARIANTS: + errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}") + suggested_fixes["variant"] = "forward" + + if ndim_spatial is not None: + ndim = ndim_spatial + if isinstance(ndim, list): + ndim = ndim[0] if ndim else 2 + if ndim not in VALID_NDIM_SPATIAL: + errors.append( + f"Invalid ndim_spatial: {ndim}. 
Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}" + ) + suggested_fixes["ndim_spatial"] = 2 + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES: + errors.append( + f"Backward variant '{variant}' requires pipeline compv3 or mem, got {pipeline}" + ) + suggested_fixes["pipeline"] = "compv3" + + ok, msg = validate_trait_combo(pipeline, epilogue, scheduler) + if not ok: + errors.append(msg) + suggested_fixes["scheduler"] = "intrawave" + + wave_cfg = _extract_wave_config(tile_config) + ok, msg = validate_wave_config(wave_cfg, arch) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + if valid_waves: + suggested_fixes["wave_m"] = valid_waves[0][0] + suggested_fixes["wave_n"] = valid_waves[0][1] + suggested_fixes["wave_k"] = valid_waves[0][2] + + warp_cfg = _extract_warp_tile_config(tile_config) + ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + valid_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + if valid_tiles: + suggested_fixes["warp_tile_m"] = valid_tiles[0][0] + suggested_fixes["warp_tile_n"] = valid_tiles[0][1] + suggested_fixes["warp_tile_k"] = valid_tiles[0][2] + + arch_data = get_arch_filter_data() + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. 
Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return GroupedConvValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=variant, + ) + + +def auto_correct_grouped_conv_config( + config: dict, +) -> Tuple[dict, GroupedConvValidationResult]: + """Auto-correct invalid grouped conv config. Returns (corrected, result).""" + result = validate_grouped_conv_config(config) + corrected = copy.deepcopy(config) + + if result.is_valid: + return corrected, result + + tile_config = corrected.setdefault("tile_config", {}) + trait_config = corrected.setdefault("trait_config", {}) + + wave_cfg = _extract_wave_config(tile_config) + arch = config.get("arch", "gfx942") + fixed_wave = auto_correct_wave(wave_cfg, arch) + tile_config["wave_m"] = fixed_wave[0] + tile_config["wave_n"] = fixed_wave[1] + tile_config["wave_k"] = fixed_wave[2] + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler) + trait_config["pipeline"] = fixed_pipeline + trait_config["scheduler"] = fixed_scheduler + + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES: + trait_config["pipeline"] = "compv3" + + if "warp_tile_m" in result.suggested_fixes: + tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"] + tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"] + tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"] + + result = validate_grouped_conv_config(corrected) + return corrected, result + + +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Run one hipcc compile+link job in a subprocess worker.""" + import subprocess + from pathlib import Path + + compile_cmd = 
args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:400]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:400]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, f"Error: {e}" + + +def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Run grouped-conv codegen once and return generated kernel header path.""" + import subprocess + from pathlib import Path + + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Remove stale kernels so header discovery is exact for this invocation. + for stale in out_dir.glob("grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + for stale in out_dir.glob("include_all_grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + + try: + res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, None, f"Codegen failed: {err}" + + generated = sorted( + out_dir.glob("grouped_conv_*.hpp"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + if not generated: + return False, None, "Codegen produced no grouped_conv_*.hpp header" + + return True, str(generated[0]), "" + except subprocess.TimeoutExpired: + return False, None, "Codegen timed out" + except Exception as e: + return False, None, f"Codegen error: {e}" + + +def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]: + return ( + c.variant, + c.ndim_spatial, + c.dtype, + c.layout, + c.arch, + c.tile_m, + c.tile_n, + c.tile_k, + c.wave_m, + c.wave_n, + c.wave_k, + c.warp_tile_m, + 
c.warp_tile_n, + c.warp_tile_k, + c.pipeline, + c.epilogue, + c.scheduler, + ) + + +def _parse_triplet(value: str) -> Tuple[int, int, int]: + parts = value.split("x") + if len(parts) != 3: + raise ValueError(f"Invalid triplet: {value}") + return int(parts[0]), int(parts[1]), int(parts[2]) + + +def _list_arch_valid_grouped_conv_configs( + codegen_script: Path, + arch: str, + dtype: str, + variant: str, + ndim_spatial: int, +) -> List[GroupedConvKernelConfig]: + """Query codegen defaults for this (arch, dtype, variant, ndim) tuple.""" + import re + import sys + + cmd = [ + sys.executable, + str(codegen_script), + "--list-configs", + "--arch", + arch, + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(ndim_spatial), + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=180) + if res.returncode != 0: + return [] + + # Example: + # grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16 + name_re = re.compile( + r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_" + r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_" + r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)" + r"(?:_.*)?$" + ) + short_to_variant = { + "fwd": "forward", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", + } + + out: List[GroupedConvKernelConfig] = [] + seen = set() + for raw in res.stdout.splitlines(): + line = raw.strip() + if not line.startswith("- grouped_conv_"): + continue + name = line[2:].strip() + m = name_re.match(name) + if not m: + continue + + v_short, dt, layout, ndim, pipe, epi, sched, tile_s, wave_s, warp_s = m.groups() + tm, tn, tk = _parse_triplet(tile_s) + wm, wn, wk = _parse_triplet(wave_s) + wtm, wtn, wtk = _parse_triplet(warp_s) + + cfg = GroupedConvKernelConfig( + variant=short_to_variant[v_short], + ndim_spatial=int(ndim), + dtype=dt, + layout=layout, + arch=arch, + tile_m=tm, + tile_n=tn, + tile_k=tk, + wave_m=wm, + 
wave_n=wn, + wave_k=wk, + warp_tile_m=wtm, + warp_tile_n=wtn, + warp_tile_k=wtk, + pipeline=pipe, + epilogue=epi, + scheduler=sched, + ) + key = _config_key(cfg) + if key not in seen: + out.append(cfg) + seen.add(key) + + return out + + +def _select_best_arch_valid_conv_config( + requested: GroupedConvKernelConfig, + candidates: List[GroupedConvKernelConfig], +) -> GroupedConvKernelConfig: + """Pick nearest arch-valid config while preferring trait exact matches.""" + + def score(c: GroupedConvKernelConfig) -> Tuple[int, int, int, int, int, int]: + tile_delta = ( + abs(c.tile_m - requested.tile_m) + + abs(c.tile_n - requested.tile_n) + + abs(c.tile_k - requested.tile_k) + ) + wave_delta = ( + abs(c.wave_m - requested.wave_m) + + abs(c.wave_n - requested.wave_n) + + abs(c.wave_k - requested.wave_k) + ) + warp_tile_delta = ( + abs(c.warp_tile_m - requested.warp_tile_m) + + abs(c.warp_tile_n - requested.warp_tile_n) + + abs(c.warp_tile_k - requested.warp_tile_k) + ) + return ( + 0 if c.pipeline == requested.pipeline else 1, + 0 if c.scheduler == requested.scheduler else 1, + 0 if c.epilogue == requested.epilogue else 1, + tile_delta, + wave_delta, + warp_tile_delta, + ) + + best = min(candidates, key=score) + selected = copy.deepcopy(best) + selected.arch = requested.arch + return selected + + +def _write_single_conv_dispatch_header( + config: GroupedConvKernelConfig, + kernel_header: Path, + dispatch_header: Path, +) -> None: + """Create a tiny dispatch header consumed by conv_ctypes_lib.cpp.""" + macros: List[str] = [] + aliases: List[str] = [] + + if config.variant == "forward": + kernel_name_symbol = "CONV_FWD_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_FWD_3D_AVAILABLE 1") + aliases.append("using ConvFwd3dLauncher = SelectedConvKernelLauncher;") + else: + macros.append("#define CONV_FWD_2D_AVAILABLE 1") + elif config.variant == "bwd_data": + kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME" + if config.ndim_spatial == 3: + 
macros.append("#define CONV_BWD_DATA_3D_AVAILABLE 1") + aliases.append("using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;") + else: + macros.append("#define CONV_BWD_DATA_2D_AVAILABLE 1") + else: + kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_BWD_WEIGHT_3D_AVAILABLE 1") + aliases.append( + "using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;" + ) + else: + macros.append("#define CONV_BWD_WEIGHT_2D_AVAILABLE 1") + + content = ( + "// Auto-generated single-kernel dispatch header for Python JIT\n" + "#pragma once\n\n" + f'#include "{kernel_header.name}"\n\n' + + "\n".join(macros) + + "\n\n" + + "\n".join(aliases) + + "\n\n" + + f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};\n" + + "static constexpr int CONV_KERNEL_COUNT = 1;\n" + ) + dispatch_header.write_text(content) + + +class GroupedConvCodegenRunner: + """Generate and compile grouped-conv JIT libraries in parallel.""" + + def __init__(self, max_workers: Optional[int] = None): + import multiprocessing + + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + self.root = Path(__file__).parent.parent + self.build_dir = self.root / "build" + self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py" + + def generate_and_compile_parallel( + self, + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + ) -> List[Optional[Path]]: + import sys + from concurrent.futures import ProcessPoolExecutor, as_completed + + if not configs: + return [] + + if not self.build_dir.exists(): + self.build_dir.mkdir(parents=True, exist_ok=True) + + ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp" + static_lib = self.build_dir / "libck_tile_dispatcher.a" + jit_root = self.build_dir / "generated_kernels" / "python_jit" + jit_root.mkdir(parents=True, exist_ok=True) + (self.build_dir / "examples").mkdir(parents=True, exist_ok=True) + + if not 
self.codegen_script.exists(): + if verbose: + print(f"Codegen script missing: {self.codegen_script}") + return [None] * len(configs) + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print("Missing conv ctypes source or static dispatcher library") + return [None] * len(configs) + + if verbose: + print( + f"Generating {len(configs)} grouped-conv kernels in parallel " + f"(workers={self.max_workers})..." + ) + + gen_jobs: List[Dict[str, Any]] = [] + job_dirs: List[Path] = [] + for i, c in enumerate(configs): + cfg_dir = jit_root / f"cfg_{i}" + cfg_dir.mkdir(parents=True, exist_ok=True) + job_dirs.append(cfg_dir) + + cmd = [ + sys.executable, + str(self.codegen_script), + "--output", + str(cfg_dir), + "--datatype", + c.dtype, + "--variant", + c.variant, + "--ndim", + str(c.ndim_spatial), + "--arch", + c.arch, + "--tile-m", + str(c.tile_m), + "--tile-n", + str(c.tile_n), + "--tile-k", + str(c.tile_k), + "--warp-m", + str(c.wave_m), + "--warp-n", + str(c.wave_n), + "--warp-k", + str(c.wave_k), + "--warp-tile-m", + str(c.warp_tile_m), + "--warp-tile-n", + str(c.warp_tile_n), + "--warp-tile-k", + str(c.warp_tile_k), + "--pipeline", + c.pipeline, + "--scheduler", + c.scheduler, + "--epilogue", + c.epilogue, + ] + gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)}) + + generated_headers: List[Optional[Path]] = [None] * len(configs) + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_conv_codegen_subprocess, job): idx + for idx, job in enumerate(gen_jobs) + } + for future in as_completed(futures): + idx = futures[future] + ok, header_path, err = future.result() + if ok and header_path: + generated_headers[idx] = Path(header_path) + if verbose: + print(f" OK [{idx}] codegen: {Path(header_path).name}") + else: + if verbose: + print(f" FAIL [{idx}] codegen: {err}") + + if verbose: + compile_count = sum(1 for h in generated_headers if h is not None) + print( + f"Compiling {compile_count} 
grouped-conv libraries in parallel " + f"(workers={self.max_workers})..." + ) + + compile_jobs: List[Dict[str, Any]] = [] + compile_to_input_index: Dict[int, int] = {} + for i, c in enumerate(configs): + hdr_path = generated_headers[i] + if hdr_path is None: + continue + + cfg_dir = job_dirs[i] + dispatch_header = cfg_dir / "conv_python_dispatch.hpp" + _write_single_conv_dispatch_header(c, hdr_path, dispatch_header) + + lib_name = ( + f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_" + f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}.so" + ) + lib_path = self.build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{self.root / 'include'}", + f"-I{self.root.parent / 'include'}", + f"-I{self.root.parent}", + f"-I{cfg_dir}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{dispatch_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.arch}", + f'-DGFX_ARCH="{c.arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_to_input_index[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": c.name, + } + ) + + results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + job_idx = futures[future] + idx = compile_to_input_index[job_idx] + success, lib_path, err = future.result() + if success and lib_path: + results_map[idx] 
= Path(lib_path) + if verbose: + status = "OK" if success else f"FAIL ({err})" + name = ( + Path(lib_path).name + if success and lib_path + else compile_jobs[job_idx]["config_name"] + ) + print(f" {status} {name}") + + return [results_map.get(i) for i in range(len(configs))] + + +# ============================================================================= +# Convenience functions +# ============================================================================= + + +def get_grouped_conv_default_config( + variant: str = "forward", + ndim_spatial: int = 2, + arch: str = "gfx942", + dtype: str = "fp16", +) -> GroupedConvKernelConfig: + """Return a valid default GroupedConvKernelConfig.""" + return GroupedConvKernelConfig( + variant=variant, + ndim_spatial=ndim_spatial, + arch=arch, + dtype=dtype, + ) + + +def format_grouped_conv_summary(config) -> str: + """Format a config (dict or GroupedConvKernelConfig) into a human-readable string.""" + if isinstance(config, GroupedConvKernelConfig): + lines = [ + f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D", + f" Arch: {config.arch}", + f" Layout: {config.layout}", + f" Dtype: {config.dtype}", + f" Tile: {config.tile_str}", + f" Wave: {config.wave_str}", + f" Warp: {config.warp_str}", + f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}", + ] + return "\n".join(lines) + + # Legacy dict support + tile_config = _get_tile_config(config) if isinstance(config, dict) else {} + trait_config = _get_trait_config(config) if isinstance(config, dict) else {} + variant = config.get("variant", "?") if isinstance(config, dict) else "?" + ndim = config.get("ndim_spatial", "?") if isinstance(config, dict) else "?" + arch = config.get("arch", "?") if isinstance(config, dict) else "?" + layout = config.get("layout", "?") if isinstance(config, dict) else "?" 
+ dtype = config.get("dtype", "fp16") if isinstance(config, dict) else "fp16" + + lines = [f"Grouped Conv Config: {variant} {ndim}D"] + lines.append(f" Arch: {arch}") + lines.append(f" Layout: {layout}") + lines.append(f" Dtype: {dtype}") + + if tile_config: + wave = _extract_wave_config(tile_config) + warp = _extract_warp_tile_config(tile_config) + lines.append( + f" Tile: M={_first(tile_config.get('tile_m', 1))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}" + ) + lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}") + lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}") + + if trait_config: + pipeline = _first(trait_config.get("pipeline", "?")) + epilogue = _first(trait_config.get("epilogue", "?")) + scheduler = _first(trait_config.get("scheduler", "?")) + lines.append( + f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}" + ) + + return "\n".join(lines) if lines else "(empty config)" + + +def setup_multiple_grouped_conv_dispatchers( + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[Optional[GroupedConvDispatcherLib]]: + """ + Setup multiple grouped-conv dispatchers in parallel. + + This keeps architecture filtering strict: + 1. Validate + auto-correct each requested config + 2. Query codegen's arch-valid config set for each (arch, dtype, variant, ndim) + 3. Map each request to nearest valid config + 4. 
Parallel codegen + parallel compile + """ + if not configs: + return [] + + codegen_script = ( + Path(__file__).parent.parent / "codegen" / "unified_grouped_conv_codegen.py" + ) + arch_valid_cache: Dict[ + Tuple[str, str, str, int], List[GroupedConvKernelConfig] + ] = {} + + selected_configs: List[Optional[GroupedConvKernelConfig]] = [] + for i, original in enumerate(configs): + c = copy.deepcopy(original) + + val = validate_grouped_conv_config(c.to_dict()) + if not val.is_valid: + corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict()) + if not corrected_result.is_valid: + if verbose: + print(f" FAIL [{i}] config remains invalid after auto-correct") + selected_configs.append(None) + continue + + tile_cfg = corrected.get("tile_config", {}) + trait_cfg = corrected.get("trait_config", {}) + c.variant = _resolve_variant( + str(_first(corrected.get("variant", c.variant))) + ) + c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial))) + c.arch = str(corrected.get("arch", c.arch)) + c.layout = str(corrected.get("layout", c.layout)) + c.dtype = str(corrected.get("dtype", c.dtype)) + c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m))) + c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n))) + c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k))) + c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m))) + c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n))) + c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k))) + c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m))) + c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n))) + c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k))) + c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline))) + c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler))) + c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue))) + + cache_key = (c.arch, c.dtype, c.variant, c.ndim_spatial) + if cache_key not in 
arch_valid_cache: + arch_valid_cache[cache_key] = _list_arch_valid_grouped_conv_configs( + codegen_script=codegen_script, + arch=c.arch, + dtype=c.dtype, + variant=c.variant, + ndim_spatial=c.ndim_spatial, + ) + if verbose and not arch_valid_cache[cache_key]: + print( + f" FAIL [{i}] no arch-valid configs listed for " + f"{c.arch}/{c.dtype}/{c.variant}/{c.ndim_spatial}d" + ) + + candidates = arch_valid_cache[cache_key] + if not candidates: + selected_configs.append(None) + continue + + selected = _select_best_arch_valid_conv_config(c, candidates) + if verbose and _config_key(selected) != _config_key(c): + print( + f" INFO [{i}] mapped to arch-valid config: " + f"{selected.tile_str} {selected.wave_str} {selected.warp_str} " + f"{selected.pipeline}/{selected.scheduler}/{selected.epilogue}" + ) + selected_configs.append(selected) + + unique_configs: List[GroupedConvKernelConfig] = [] + unique_index_by_key: Dict[Tuple[Any, ...], int] = {} + input_to_unique: List[Optional[int]] = [] + for cfg in selected_configs: + if cfg is None: + input_to_unique.append(None) + continue + key = _config_key(cfg) + if key not in unique_index_by_key: + unique_index_by_key[key] = len(unique_configs) + unique_configs.append(cfg) + input_to_unique.append(unique_index_by_key[key]) + + runner = GroupedConvCodegenRunner(max_workers=max_workers) + unique_lib_paths = runner.generate_and_compile_parallel( + unique_configs, verbose=verbose + ) + + libs: List[Optional[GroupedConvDispatcherLib]] = [] + loaded_cache: Dict[int, Optional[GroupedConvDispatcherLib]] = {} + for input_idx, unique_idx in enumerate(input_to_unique): + if unique_idx is None: + libs.append(None) + continue + + if unique_idx in loaded_cache: + libs.append(loaded_cache[unique_idx]) + continue + + path = ( + unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None + ) + disp: Optional[GroupedConvDispatcherLib] = None + if path and path.exists(): + try: + lib = ctypes.CDLL(str(path)) + disp = 
GroupedConvDispatcherLib(lib, path) + disp.initialize() + except Exception as e: + if verbose: + print(f" FAIL [{input_idx}] failed to load {path}: {e}") + loaded_cache[unique_idx] = disp + libs.append(disp) + + return libs + + +def detect_gpu_arch() -> str: + """Detect GPU architecture using rocminfo.""" + try: + out = subprocess.check_output( + ["rocminfo"], stderr=subprocess.DEVNULL, text=True + ) + for line in out.split("\n"): + if "gfx" in line.lower() and "name:" in line.lower(): + for part in line.split(): + if part.startswith("gfx"): + return part + except Exception: + pass + return "gfx942" diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py index b19c18a13a..98ba18ab51 100644 --- a/dispatcher/scripts/compile_gemm_examples.py +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -94,17 +94,17 @@ def find_hipcc() -> str: def extract_conv_kernel_declarations(source_file: Path) -> list: - """Extract CONVOLUTION kernel declarations from C++ source file. + """Extract GROUPED CONVOLUTION kernel declarations from C++ source file. - Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Supports DECL_GROUPED_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. 
""" content = source_file.read_text() declarations = [] seen = set() - # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) - set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" for match in re.finditer(set_pattern, content, re.DOTALL): set_name = match.group(1) @@ -396,24 +396,23 @@ def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") - def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: - """Generate convolution kernels using unified_conv_codegen.""" + """Generate grouped convolution kernels using unified_grouped_conv_codegen.""" kernel_dir = get_generated_kernels_dir() kernel_dir.mkdir(parents=True, exist_ok=True) - # Import conv codegen codegen_dir = get_dispatcher_root() / "codegen" sys.path.insert(0, str(codegen_dir)) try: - from unified_conv_codegen import ( - UnifiedConvCodegen, - ConvKernelConfig, - ConvVariant, + from unified_grouped_conv_codegen import ( + UnifiedGroupedConvCodegen as UnifiedConvCodegen, + GroupedConvKernelConfig as ConvKernelConfig, + GroupedConvVariant as ConvVariant, TileConfig, - TraitConfig, + GroupedConvTraitConfig as TraitConfig, ) except ImportError as e: - print_error(f" Failed to import conv codegen: {e}") + print_error(f" Failed to import grouped conv codegen: {e}") return 0 codegen = UnifiedConvCodegen(kernel_dir) @@ -1564,9 +1563,9 @@ def build_exact_conv_kernel_filename(decl: dict) -> str: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = "bwd_weight" else: type_prefix = conv_type @@ -1601,9 +1600,9 @@ def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> boo else: variant = "forward" - # Use unified_conv_codegen + # Use 
unified_grouped_conv_codegen codegen_dir = get_dispatcher_root() / "codegen" - codegen_script = codegen_dir / "unified_conv_codegen.py" + codegen_script = codegen_dir / "unified_grouped_conv_codegen.py" output_dir = get_generated_kernels_dir() cmd = [ @@ -1661,9 +1660,9 @@ def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = "bwd_weight" else: type_prefix = conv_type @@ -1865,7 +1864,9 @@ In your C++ code, declare kernels like: if not gemm_declarations and not conv_declarations: print_error(" No kernel declarations found!") - print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv") + print( + " Add DECL_KERNEL_SET for GEMM or DECL_GROUPED_CONV_KERNEL_SET for Grouped Conv" + ) return 1 # Handle GEMM declarations @@ -1913,7 +1914,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid configuration: {decl_name}") + print(f"\n WARNING Invalid configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -1926,7 +1927,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -1936,7 +1937,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -1945,16 +1946,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - 
f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -1962,15 +1963,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(gemm_declarations)} configurations valid") + print(f" OK All {len(gemm_declarations)} configurations valid") # Expand GEMM declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -1994,7 +1995,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... 
and {len(expanded) - 3} more") @@ -2002,11 +2003,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_gemm) > len(gemm_declarations): print( - f"\n Total: {len(gemm_declarations)} declarations → {len(expanded_gemm)} configurations" + f"\n Total: {len(gemm_declarations)} declarations -> {len(expanded_gemm)} configurations" ) gemm_declarations = expanded_gemm @@ -2054,7 +2055,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_conv_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid conv configuration: {decl_name}") + print(f"\n WARNING Invalid conv configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -2067,7 +2068,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -2077,7 +2078,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -2086,16 +2087,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the 
auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -2103,15 +2104,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(conv_declarations)} configurations valid") + print(f" OK All {len(conv_declarations)} configurations valid") # Expand Conv declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -2134,7 +2135,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... 
and {len(expanded) - 3} more") @@ -2142,11 +2143,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_conv) > len(conv_declarations): print( - f"\n Total: {len(conv_declarations)} declarations → {len(expanded_conv)} configurations" + f"\n Total: {len(conv_declarations)} declarations -> {len(expanded_conv)} configurations" ) conv_declarations = expanded_conv diff --git a/dispatcher/scripts/compile_grouped_conv_examples.py b/dispatcher/scripts/compile_grouped_conv_examples.py new file mode 100644 index 0000000000..32fe70a2de --- /dev/null +++ b/dispatcher/scripts/compile_grouped_conv_examples.py @@ -0,0 +1,882 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Self-contained build script for C++ grouped convolution examples. + +Parses DECL_GROUPED_CONV_KERNEL_SET declarations from source files, +generates the needed kernels, and compiles the example. + +Includes validation and auto-correction via wildcard expansion. 
+ +Usage: + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/02_grouped_conv_forward.cpp + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/03_grouped_conv_validation.cpp --no-compile +""" + +import argparse +import os +import re +import subprocess +import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +CK_ROOT = DISPATCHER_DIR.parent + +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + print_phase, + print_success, + print_error, + print_info, + find_hipcc, + get_arch_filter_data, + get_build_dir, + get_ck_root, + get_dispatcher_root, + get_generated_kernels_dir, +) + + +def extract_grouped_conv_declarations(source_file: Path) -> list: + """Extract DECL_GROUPED_CONV_KERNEL_SET declarations from C++ source.""" + content = source_file.read_text() + declarations = [] + + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + # Find all DECL_GROUPED_CONV_KERNEL_SET blocks by matching parentheses + pattern_start = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*," + for match in re.finditer(pattern_start, content): + set_name = match.group(1) + start_pos = match.end() + + # Find matching closing paren by counting parens + paren_count = 1 # We're already inside the first paren + end_pos = start_pos + for i, c in enumerate(content[start_pos:]): + if c == "(": + paren_count += 1 + elif c == ")": + paren_count -= 1 + if paren_count == 0: + end_pos = start_pos + i + break + + set_body = content[start_pos:end_pos] + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + conv_type = 
add_match.group(3) + default_pipeline = ( + "compv3" if conv_type in ("bwd_data", "bwd_weight") else "compv4" + ) + declarations.append( + { + "set": set_name, + "dtype": add_match.group(1), + "layout": add_match.group(2), + "conv_type": conv_type, + "tile_k": int(add_match.group(4)), + "tile_c": int(add_match.group(5)), + "num_dims": 2, + "pipeline": default_pipeline, + "scheduler": "intrawave", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "arch": "gfx942", + } + ) + + # Pattern 2: Full ConvSig()/ConvAlgo() specification + # Find all .add( positions that start with ConvSig() + full_add = r"\.add\s*\(\s*ConvSig\(\)" + add_positions = [m.start() for m in re.finditer(full_add, set_body)] + + for pos in add_positions: + # Find matching closing paren by counting parens + paren_count = 0 + in_add = False + end = pos + for i, c in enumerate(set_body[pos:]): + if c == "(": + paren_count += 1 + in_add = True + elif c == ")": + paren_count -= 1 + if in_add and paren_count == 0: + end = pos + i + 1 + break + + add_str = set_body[pos:end] + + # Extract signature part (between ConvSig() and ConvAlgo()) + sig_match = re.search(r"ConvSig\(\)(.*?)ConvAlgo\(\)", add_str, re.DOTALL) + if not sig_match: + continue + sig_str = sig_match.group(1) + + # Extract algorithm part (between ConvAlgo() and arch string) + algo_match = re.search( + r'ConvAlgo\(\)(.*?),\s*"(\w+)"\s*\)', add_str, re.DOTALL + ) + if not algo_match: + continue + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse signature + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"(\w+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"(\w+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"(\w+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + 
num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + # Parse algorithm + tile_k, tile_c = 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*\d+\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if tile_match: + tile_k = int(tile_match.group(1)) + tile_c = int(tile_match.group(2)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv4" + pipeline_match = re.search(r'\.pipeline\s*\(\s*"(\w+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"(\w+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + # Parse additional parameters + vector_a, vector_b, vector_c = 4, 8, 8 + vector_match = re.search( + r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if vector_match: + vector_a = int(vector_match.group(1)) + vector_b = int(vector_match.group(2)) + vector_c = int(vector_match.group(3)) + + block_per_cu = 1 + block_per_cu_match = re.search(r"\.block_per_cu\s*\(\s*(\d+)", algo_str) + if block_per_cu_match: + block_per_cu = int(block_per_cu_match.group(1)) + + memory_op = "set" + memory_op_match = re.search(r'\.memory_op\s*\(\s*"(\w+)"', algo_str) + if memory_op_match: + memory_op = memory_op_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"(\w+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Parse 
num_wave_groups (for V5 pipeline) + num_wave_groups = 1 + nwg_match = re.search(r"\.num_wave_groups\s*\(\s*(\d+)", algo_str) + if nwg_match: + num_wave_groups = int(nwg_match.group(1)) + + # Parse num_groups_to_merge (for merged group grouped convolution) + num_groups_to_merge = 1 + ngm_match = re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)", algo_str) + if ngm_match: + num_groups_to_merge = int(ngm_match.group(1)) + + # Parse double_smem_buffer (for V4 pipeline) + double_smem_buffer = False + dsb_match = re.search( + r"\.double_smem_buffer\s*\(\s*(true|false)", algo_str, re.I + ) + if dsb_match: + double_smem_buffer = dsb_match.group(1).lower() == "true" + + # Parse padding flags + pad_m, pad_n, pad_k = True, True, True + padding_match = re.search( + r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)", + algo_str, + re.I, + ) + if padding_match: + pad_m = padding_match.group(1).lower() == "true" + pad_n = padding_match.group(2).lower() == "true" + pad_k = padding_match.group(3).lower() == "true" + + declarations.append( + { + "set": set_name, + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "tile_k": tile_k, + "tile_c": tile_c, + "num_dims": num_dims, + "pipeline": pipeline, + "scheduler": scheduler, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "vector_a": vector_a, + "vector_b": vector_b, + "vector_c": vector_c, + "block_per_cu": block_per_cu, + "memory_op": memory_op, + "epilogue": epilogue, + "num_wave_groups": num_wave_groups, + "num_groups_to_merge": num_groups_to_merge, + "double_smem_buffer": double_smem_buffer, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "arch": arch, + } + ) + + return declarations + + +# ============================================================================= +# VALIDATION AND AUTO-CORRECTION +# ============================================================================= + + +def 
is_grouped_conv_wildcard_declaration(decl: dict) -> bool: + """Check if a declaration uses wildcards (-1 or '*').""" + wildcard_fields = ["wave_m", "wave_n", "warp_m", "warp_n", "pipeline", "scheduler"] + for field in wildcard_fields: + val = decl.get(field) + if val == -1 or val == "*": + return True + return False + + +def validate_grouped_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a grouped conv kernel configuration against known supported combinations. + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_grouped_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, "cshuffle", scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + 
.get(dtype_key, [[32, 32, 16], [16, 16, 16], [16, 16, 32]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def expand_grouped_conv_declaration_with_arch_filter( + decl: dict, arch: str = "gfx942" +) -> list: + """Expand a grouped conv declaration with wildcards into valid configurations. + + Wildcards: + - wave_m/wave_n = -1: Try all valid wave configs for this arch + - warp_m/warp_n = -1: Try all valid warp tiles for this arch/dtype + - pipeline/scheduler = "*": Try all valid combinations + + Returns a list of fully-specified declarations. 
+ """ + arch_data = get_arch_filter_data() + dtype = decl.get("dtype", "fp16") + + # Get valid combinations for this arch + valid_wave_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + valid_warp_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + + # Valid pipelines and schedulers + valid_pipelines = ["compv3", "compv4"] + valid_schedulers = ["intrawave"] # interwave often unsupported + + # Determine which fields need expansion + expand_wave = decl.get("wave_m", 2) == -1 or decl.get("wave_n", 2) == -1 + expand_warp = decl.get("warp_m", 32) == -1 or decl.get("warp_n", 32) == -1 + expand_pipeline = decl.get("pipeline", "compv4") == "*" + expand_scheduler = decl.get("scheduler", "intrawave") == "*" + + # Build combinations + wave_options = ( + valid_wave_combos + if expand_wave + else [[decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]] + ) + warp_options = ( + valid_warp_tiles + if expand_warp + else [[decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]] + ) + pipeline_options = ( + valid_pipelines if expand_pipeline else [decl.get("pipeline", "compv4")] + ) + scheduler_options = ( + valid_schedulers if expand_scheduler else [decl.get("scheduler", "intrawave")] + ) + + expanded = [] + for wave in wave_options: + for warp in warp_options: + for pipeline in pipeline_options: + for scheduler in scheduler_options: + # Skip known invalid combinations + if (pipeline, "cshuffle", scheduler) in arch_data[ + "trait_unsupported" + ]: + continue + + new_decl = decl.copy() + new_decl["wave_m"] = wave[0] + new_decl["wave_n"] = wave[1] + new_decl["wave_k"] = wave[2] + new_decl["warp_m"] = warp[0] + new_decl["warp_n"] = warp[1] + new_decl["warp_k"] = warp[2] + new_decl["pipeline"] = pipeline + new_decl["scheduler"] = scheduler + + expanded.append(new_decl) + + # If no valid 
expansions, return original (will fail validation later) + if not expanded: + return [decl] + + # Return first valid config (or all if needed) + return expanded[:1] # Just use first valid config for grouped conv + + +def validate_and_expand_grouped_conv_declarations( + declarations: list, arch: str, verbose: bool = False +) -> list: + """Validate declarations and auto-correct invalid ones via wildcard expansion.""" + print(f"\n Validating against {arch} arch filter...") + + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + # Check for wildcards + if is_grouped_conv_wildcard_declaration(decl): + wildcard_count += 1 + continue + + is_valid, error_msg = validate_grouped_conv_kernel_config(decl, decl_arch) + if not is_valid: + print(f"\n WARNING Invalid grouped conv configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} -> [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv4") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" + ) + 
corrections.append( + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" - {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" OK All {len(declarations)} configurations valid") + + # Expand wildcards + print("\n Expanding wildcards to valid configurations...") + expanded_declarations = [] + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + expanded = expand_grouped_conv_declaration_with_arch_filter(decl, decl_arch) + expanded_declarations.extend(expanded) + + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}" + ) + if len(expanded) > 3: + print(f" ... 
and {len(expanded) - 3} more") + elif is_grouped_conv_wildcard_declaration(decl) and len(expanded) == 1: + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") + + if len(expanded_declarations) != len(declarations): + print( + f"\n Total: {len(declarations)} declarations -> {len(expanded_declarations)} configurations" + ) + + return expanded_declarations + + +def _generate_single_grouped_conv_kernel(args: tuple) -> tuple: + """Generate one grouped conv kernel (picklable for ProcessPoolExecutor). + + Args: (decl, output_dir_str, gpu_target) + Returns: (idx, filepath_str or None, error_str or None) + """ + decl, output_dir_str, gpu_target = args + output_dir = Path(output_dir_str) + idx = decl.get("_idx", 0) + + try: + from codegen_common import TileConfig + from unified_grouped_conv_codegen import ( + GroupedConvKernelConfig, + GroupedConvTraitConfig, + GroupedConvVariant, + UnifiedGroupedConvCodegen, + ) + + # Map conv_type to variant + variant = GroupedConvVariant.FORWARD + if decl["conv_type"] == "bwd_data": + variant = GroupedConvVariant.BACKWARD_DATA + elif decl["conv_type"] == "bwd_weight": + variant = GroupedConvVariant.BACKWARD_WEIGHT + + pipeline = decl.get("pipeline", "compv4") + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create tile config (tile_m=tile_k, tile_n=tile_c for conv GEMM view) + tile = TileConfig( + tile_m=decl["tile_k"], + tile_n=decl["tile_c"], + tile_k=adj_tile_k, + warp_m=decl["wave_m"], + warp_n=decl["wave_n"], + warp_k=decl.get("wave_k", 1), + warp_tile_m=decl["warp_m"], + warp_tile_n=decl["warp_n"], + warp_tile_k=decl["warp_k"], + ) + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler=decl["scheduler"], + epilogue=decl.get("epilogue", "cshuffle"), + double_smem_buffer=decl.get("double_smem_buffer", False), + pad_m=decl.get("pad_m", True), + 
pad_n=decl.get("pad_n", True), + pad_k=decl.get("pad_k", True), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + ) + + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=variant, + ndim_spatial=decl["num_dims"], + arch=decl.get("arch", gpu_target), + vector_size_a=decl.get("vector_a", 4), + vector_size_b=decl.get("vector_b", 8), + vector_size_c=decl.get("vector_c", 8), + block_per_cu=decl.get("block_per_cu", 1), + num_wave_groups=decl.get("num_wave_groups", 1), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + double_smem_buffer=decl.get("double_smem_buffer", False), + ) + + codegen = UnifiedGroupedConvCodegen(output_dir, gpu_target=gpu_target) + kernel_path, _ = codegen.generate_kernel(config, decl["dtype"], variant) + return (idx, str(kernel_path), None) + + except Exception as e: + return (idx, None, str(e)) + + +def generate_grouped_conv_kernels( + declarations: list, + output_dir: Path, + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, +) -> list: + """Generate grouped convolution kernels using unified_grouped_conv_codegen. + + Uses ProcessPoolExecutor for parallel kernel generation. 
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare work items (add _idx for ordering) + work_items = [] + for idx, decl in enumerate(declarations): + decl_copy = decl.copy() + decl_copy["_idx"] = idx + work_items.append((decl_copy, str(output_dir), gpu_target)) + + max_workers = max_workers or min(len(work_items), os.cpu_count() or 4) + generated = [] + failed = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_grouped_conv_kernel, w): w[0]["_idx"] + for w in work_items + } + for future in as_completed(futures): + idx, path, err = future.result() + if path: + generated.append(Path(path)) + print_info(f" Generated: {Path(path).name}") + else: + failed.append((idx, err)) + print_error(f" Failed kernel {idx + 1}: {err}") + + if failed: + for idx, err in failed[:3]: + print_error(f" Kernel {idx + 1}: {err[:200]}") + if len(failed) > 3: + print_error(f" ... and {len(failed) - 3} more failures") + + return generated + + +def compile_grouped_conv_example( + source_file: Path, + output_bin: Path, + kernel_headers: list, + hipcc: str, + gpu_target: str, +) -> bool: + """Compile the C++ example with generated kernels.""" + kernel_dir = get_generated_kernels_dir() + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + # Build include flags for generated kernels + kernel_includes = [] + for header in kernel_headers: + kernel_includes.extend(["-include", str(header)]) + + # Add define to indicate kernels are available + defines = ["-DGROUPED_CONV_KERNEL_AVAILABLE=1"] + + cmd = [ + hipcc, + "-std=c++20", + "-O2", + f"--offload-arch={gpu_target}", + *includes, + *defines, + *kernel_includes, + "-o", + str(output_bin), + str(source_file), + ] + + print_info(f" Compiling: {source_file.name}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if 
result.returncode != 0: + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()][:5] + for err_line in errors: + print_error(f" {err_line}") + return False + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Build C++ grouped convolution example with self-contained kernel generation" + ) + parser.add_argument("source", help="Source file (.cpp)") + parser.add_argument("--output", "-o", help="Output binary name") + parser.add_argument("--gpu-target", default="gfx942", help="GPU target") + parser.add_argument( + "--no-compile", action="store_true", help="Only generate kernels, don't compile" + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--jobs", + "-j", + type=int, + default=None, + help="Parallel jobs for kernel generation (default: cpu_count)", + ) + args = parser.parse_args() + + # Resolve source file + source_file = Path(args.source) + if not source_file.is_absolute(): + candidates = [ + get_dispatcher_root() / args.source, + Path.cwd() / args.source, + ] + for c in candidates: + if c.exists(): + source_file = c + break + + if not source_file.exists(): + print_error(f"Source file not found: {source_file}") + return 1 + + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + output_name = args.output or source_file.stem + output_bin = build_dir / output_name + + print_success("=== Grouped Conv Example Builder (Self-Contained) ===") + + # Phase 1: Extract declarations + print_phase(1, "Scanning for DECL_GROUPED_CONV_KERNEL_SET...") + declarations = extract_grouped_conv_declarations(source_file) + + if not declarations: + print_error(" No DECL_GROUPED_CONV_KERNEL_SET declarations found!") + return 1 + + print(f" Found {len(declarations)} kernel declaration(s):") + for decl in declarations: + name = f"{decl['dtype']}_{decl['conv_type']}_{decl['num_dims']}d_{decl['tile_k']}x{decl['tile_c']}" + print(f" 
[{decl['set']}] {name}") + + # Phase 2: Validate and expand + print_phase(2, "Validating and expanding declarations...") + declarations = validate_and_expand_grouped_conv_declarations( + declarations, args.gpu_target, args.verbose + ) + print() + + # Phase 3: Generate kernels + print_phase(3, "Generating kernels...") + generated = generate_grouped_conv_kernels( + declarations, kernel_dir, args.gpu_target, max_workers=args.jobs + ) + + if not generated: + print_error(" No kernels generated!") + return 1 + + print(f" Generated {len(generated)} kernel file(s)") + print() + + # Phase 4: Compile (optional) + if args.no_compile: + print_info("Skipping compilation (--no-compile)") + print() + print_success("=== Kernel Generation Complete ===") + print(f"Kernels in: {kernel_dir}") + return 0 + + print_phase(4, "Compiling example...") + hipcc_path = find_hipcc() + + if not hipcc_path: + print_error(" hipcc not found. Install ROCm or set HIPCC env var.") + print(" To compile manually:") + ck_root = get_dispatcher_root().parent + print( + f" hipcc -std=c++20 -O2 -I{ck_root / 'include'} -I{get_dispatcher_root() / 'include'} \\" + ) + print(f" -I{kernel_dir} \\") + for h in generated[:1]: + print(f" -include {h} \\") + print(" -DGROUPED_CONV_KERNEL_AVAILABLE=1 \\") + print(f" --offload-arch={args.gpu_target} \\") + print(f" {source_file} -o {output_bin}") + return 1 + + build_dir.mkdir(parents=True, exist_ok=True) + + if not compile_grouped_conv_example( + source_file, output_bin, generated, hipcc_path, args.gpu_target + ): + print_error(" Compilation failed!") + return 1 + + print_success(f" Output: {output_bin}") + print() + + print_success("=== Build Complete ===") + print() + print("Run with:") + print(f" {output_bin}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py index d3bb619174..20952cd91f 100755 --- a/dispatcher/scripts/example_kernel_builder.py 
+++ b/dispatcher/scripts/example_kernel_builder.py @@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str: def parse_conv_declarations(content: str) -> List[Dict]: - """Parse DECL_CONV_KERNEL_SET declarations with all parameters.""" + """Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters.""" kernels = [] - for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content): + for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content): body = extract_balanced_parens(content, match.end() - 1) if not body: continue @@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str: n = len(content) # Patterns that indicate a string is problematic and should be stripped - problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("] while i < n: # Check for raw string literal: R"delimiter(...)delimiter" @@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: content = source_path.read_text() content = strip_cpp_strings_and_comments(content) - if "DECL_CONV_KERNEL_SET" in content: + if "DECL_GROUPED_CONV_KERNEL_SET" in content: return "conv", parse_conv_declarations(content) elif "DECL_KERNEL_SET" in content: return "gemm", parse_gemm_declarations(content) @@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str: def generate_conv_registration( kernel_headers: List[Path], example_name: str, kernels: List[Dict] ) -> str: - """Generate Conv kernel registration code for the dispatcher registry.""" + """Generate Conv kernel registration code for the dispatcher registry. + + Creates real GroupedConvKernelInstance entries backed by the generated + launcher's launch() method via the conv backend RunFn factories. 
+ """ if not kernel_headers: return " // No kernels to register" lines = [] - lines.append( - " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" - ) - # For conv, we provide direct access to kernel launchers for i, h in enumerate(kernel_headers): - kernel_name = h.stem - lines.append(f" // Kernel {i + 1}: {kernel_name}") + kname = h.stem + ns = f"ns_{kname}" + launcher = f"{ns}::{kname}_Launcher" + + # Determine direction and ndim from the kernel header name + if "_fwd_" in kname: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + elif "_bwd_data_" in kname or "_bwdd_" in kname: + direction = "BackwardData" + run_fn_factory = "make_conv_bwd_data_run_fn" + elif "_bwd_weight_" in kname or "_bwdw_" in kname: + direction = "BackwardWeight" + run_fn_factory = "make_conv_bwd_weight_run_fn" + else: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + + ndim = 3 if "_3d_" in kname else 2 + + # Parse dtype from name (e.g. grouped_conv_fwd_fp16_...) + dtype = "fp16" + for dt in ["fp16", "bf16", "fp32"]: + if f"_{dt}_" in kname: + dtype = dt + break + + # Parse tile, wave, warp from name. + # Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_... 
+ import re as _re + + tile_m, tile_n, tile_k = 1, 128, 128 + wave_m, wave_n, wave_k = 2, 2, 1 + warp_m, warp_n, warp_k = 32, 32, 16 + + triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname) + if len(triplets) >= 1: + tile_m, tile_n, tile_k = ( + int(triplets[0][0]), + int(triplets[0][1]), + int(triplets[0][2]), + ) + if len(triplets) >= 2: + wave_m, wave_n, wave_k = ( + int(triplets[1][0]), + int(triplets[1][1]), + int(triplets[1][2]), + ) + if len(triplets) >= 3: + warp_m, warp_n, warp_k = ( + int(triplets[2][0]), + int(triplets[2][1]), + int(triplets[2][2]), + ) + + pipeline = "compv4" if "compv4" in kname else "compv3" + scheduler = "interwave" if "interwave" in kname else "intrawave" + epilogue = "cshuffle" if "cshuffle" in kname else "default" + + # ConvConfigBase defaults + vec_a, vec_b, vec_c = 4, 8, 8 + block_per_cu = 1 + num_wave_groups = 1 + num_groups_to_merge = 1 + + lines.append(f" // Kernel {i + 1}: {kname}") + lines.append(" {") + lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};") + lines.append(f' key_{i}.dtype_in = "{dtype}";') + lines.append(f' key_{i}.dtype_wei = "{dtype}";') + lines.append(f' key_{i}.dtype_out = "{dtype}";') + lines.append(f' key_{i}.layout = "nhwgc";') + lines.append(f" key_{i}.ndim_spatial = {ndim};") + lines.append( + f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};" + ) + lines.append(f" key_{i}.tile_m = {tile_m};") + lines.append(f" key_{i}.tile_n = {tile_n};") + lines.append(f" key_{i}.tile_k = {tile_k};") + lines.append(f" key_{i}.wave_m = {wave_m};") + lines.append(f" key_{i}.wave_n = {wave_n};") + lines.append(f" key_{i}.wave_k = {wave_k};") + lines.append(f" key_{i}.warp_m = {warp_m};") + lines.append(f" key_{i}.warp_n = {warp_n};") + lines.append(f" key_{i}.warp_k = {warp_k};") + lines.append(f' key_{i}.pipeline = "{pipeline}";') + lines.append(f' key_{i}.scheduler = "{scheduler}";') + lines.append(f' key_{i}.epilogue = "{epilogue}";') + lines.append(f" key_{i}.vector_size_a = 
{vec_a};") + lines.append(f" key_{i}.vector_size_b = {vec_b};") + lines.append(f" key_{i}.vector_size_c = {vec_c};") + lines.append(f" key_{i}.block_per_cu = {block_per_cu};") + lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};") + lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};") + lines.append(f" key_{i}.arch = arch;") + lines.append( + f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();" + ) + lines.append( + f' auto inst_{i} = std::make_shared(key_{i}, "{kname}", std::move(run_fn_{i}));' + ) + lines.append(f" registry.register_kernel(key_{i}, inst_{i});") + lines.append(" }") return "\n".join(lines) -def generate_conv_kernels( - kernels: List[Dict], output_dir: Path, codegen_dir: Path -) -> bool: - """Generate Conv kernels for ALL declarations using unified codegen.""" - if not kernels: - return False - +def _build_conv_codegen_cmd( + idx: int, k: Dict, codegen_dir: Path, output_dir: Path +) -> Tuple[int, List[str], str]: + """Build the command for a single conv kernel codegen invocation.""" variant_map = { "forward": "forward", "bwd_data": "bwd_data", @@ -997,93 +1095,130 @@ def generate_conv_kernels( "bwd_weight": "bwd_weight", "backward_weight": "bwd_weight", } + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_grouped_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if 
k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + return (idx, cmd, str(codegen_dir)) + + +def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple conv kernels are declared. 
+ """ + if not kernels: + return False + + work_items = [ + _build_conv_codegen_cmd(idx, k, codegen_dir, output_dir) + for idx, k in enumerate(kernels) + ] success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) - # Generate a kernel for EACH declaration - for idx, k in enumerate(kernels): - variant = variant_map.get(k.get("conv_type", "forward"), "forward") - - cmd = [ - sys.executable, - str(codegen_dir / "unified_conv_codegen.py"), - "--datatype", - k.get("dtype", "fp16"), - "--variant", - variant, - "--ndim", - str(k.get("ndim", 2)), - "--output", - str(output_dir), - ] - - # Add optional parameters if specified - if k.get("tile_m"): - cmd.extend(["--tile-m", str(k["tile_m"])]) - if k.get("tile_n"): - cmd.extend(["--tile-n", str(k["tile_n"])]) - if k.get("warp_m"): - cmd.extend(["--warp-m", str(k["warp_m"])]) - if k.get("warp_n"): - cmd.extend(["--warp-n", str(k["warp_n"])]) - if k.get("warp_k"): - cmd.extend(["--warp-k", str(k["warp_k"])]) - if k.get("warp_tile_m"): - cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) - if k.get("warp_tile_n"): - cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) - if k.get("warp_tile_k"): - cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) - if k.get("pipeline"): - cmd.extend(["--pipeline", k["pipeline"]]) - if k.get("scheduler"): - cmd.extend(["--scheduler", k["scheduler"]]) - if k.get("epilogue"): - cmd.extend(["--epilogue", k["epilogue"]]) - if k.get("vector_a"): - cmd.extend(["--vector-a", str(k["vector_a"])]) - if k.get("vector_b"): - cmd.extend(["--vector-b", str(k["vector_b"])]) - if k.get("vector_c"): - cmd.extend(["--vector-c", str(k["vector_c"])]) - if k.get("block_per_cu"): - cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) - if k.get("num_wave_groups"): - cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) - if k.get("num_groups_to_merge"): - cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) - if k.get("double_smem_buffer") is not None: - 
cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) - if k.get("tile_k"): - cmd.extend(["--tile-k", str(k["tile_k"])]) - - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 +def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + def generate_gemm_kernels( kernels: List[Dict], output_dir: Path, codegen_dir: Path ) -> bool: - """Generate GEMM kernels for ALL declarations using unified codegen.""" + """Generate GEMM kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple kernels are declared. 
+ """ import json if not kernels: return False - success_count = 0 - - # Generate a kernel for EACH declaration + # Build all commands upfront + work_items = [] for idx, k in enumerate(kernels): variant = "multi_d" if k.get("elementwise_op") else "standard" - # Build tile config JSON for this specific kernel tile_config = { "tile_m": [k.get("tile_m", 128)], "tile_n": [k.get("tile_n", 128)], @@ -1125,13 +1260,20 @@ def generate_gemm_kernels( config_json, ] - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + work_items.append((idx, cmd, str(codegen_dir))) + + # Run all codegen subprocesses in parallel + success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 @@ -1229,15 +1371,17 @@ def main(): if example_type == "gemm": kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) else: - k = kernels[0] if kernels else {} - variant = k.get("conv_type", "forward") prefix_map = { - "forward": "conv_fwd", - "bwd_data": "conv_bwdd", - "bwd_weight": "conv_bwdw", + "forward": "grouped_conv_fwd", + "bwd_data": "grouped_conv_bwd_data", + "bwd_weight": "grouped_conv_bwd_weight", } - prefix = prefix_map.get(variant, "conv_fwd") - kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + # Collect headers from ALL variants present in declarations + variants_used = set(k.get("conv_type", "forward") for k in kernels) + kernel_headers = [] + for variant in variants_used: + prefix = prefix_map.get(variant, "grouped_conv_fwd") + 
kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp")) if not kernel_headers: print(f"[{target_name}] No kernel headers generated!") @@ -1347,29 +1491,39 @@ def main(): ) if has_bwd_data: - bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") - if bwdd_kernel: - bwdd_ns = f"ns_{bwdd_kernel.stem}" - launcher_aliases.append( - f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_data_" + ) + if not bwd_data_kernel: + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdd_" ) - if not has_fwd: # If no fwd, use bwd_data as first + if bwd_data_kernel: + bwd_data_ns = f"ns_{bwd_data_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" + ) + if not has_fwd: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" ) if has_bwd_weight: - bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") - if bwdw_kernel: - bwdw_ns = f"ns_{bwdw_kernel.stem}" - launcher_aliases.append( - f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_weight_" + ) + if not bwd_weight_kernel: + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdw_" ) - if ( - not has_fwd and not has_bwd_data - ): # If no fwd or bwdd, use bwdw as first + if bwd_weight_kernel: + bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}" + launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" + ) + if not has_fwd and not has_bwd_data: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = 
{bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" ) launcher_section = "\n".join(launcher_aliases) @@ -1382,14 +1536,16 @@ def main(): #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp" namespace generated {{ // Kernel launchers for direct use {launcher_section} -// Registration function -inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +// Registration function (takes GroupedConvRegistry for conv kernels) +inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{ {register_body} }} @@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri """ header_path.write_text(header_content) - print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + print(f"[{target_name}] OK {len(obj_files)} kernels compiled") return 0 diff --git a/dispatcher/scripts/generate_conv_dispatch_header.py b/dispatcher/scripts/generate_conv_dispatch_header.py new file mode 100644 index 0000000000..55cc085ed9 --- /dev/null +++ b/dispatcher/scripts/generate_conv_dispatch_header.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Generate the conv_python_dispatch.hpp header for the Python conv library. + +Reads the include_all headers to find available kernels and creates dispatch +aliases for 2D/3D x fwd/bwd_data/bwd_weight. 
+""" + +import argparse +import re +from pathlib import Path + + +def find_3d_launcher(include_all_path: Path, variant_prefix: str) -> str: + """Find first 3D launcher name from an include_all header.""" + text = include_all_path.read_text() + pattern = rf"(grouped_conv_{variant_prefix}_\w+_3d_\w+)\.hpp" + match = re.search(pattern, text) + if match: + return match.group(1) + "_Launcher" + return "" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--kernel-dir", required=True) + parser.add_argument("--output", required=True) + args = parser.parse_args() + + kdir = Path(args.kernel_dir) + + fwd_3d = find_3d_launcher(kdir / "include_all_grouped_conv_fwd_kernels.hpp", "fwd") + bwd_data_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_data_kernels.hpp", "bwd_data" + ) + bwd_weight_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_weight_kernels.hpp", "bwd_weight" + ) + + lines = [ + "// Auto-generated dispatch header for Python conv library", + "#pragma once", + "", + "// Forward kernels", + '#include "include_all_grouped_conv_fwd_kernels.hpp"', + "#define CONV_FWD_2D_AVAILABLE 1", + ] + if fwd_3d: + lines += [ + "#define CONV_FWD_3D_AVAILABLE 1", + f"using ConvFwd3dLauncher = {fwd_3d};", + ] + lines += [ + "", + "// Backward data kernels", + '#include "include_all_grouped_conv_bwd_data_kernels.hpp"', + "#define CONV_BWD_DATA_2D_AVAILABLE 1", + ] + if bwd_data_3d: + lines += [ + "#define CONV_BWD_DATA_3D_AVAILABLE 1", + f"using ConvBwdData3dLauncher = {bwd_data_3d};", + ] + lines += [ + "", + "// Backward weight kernels", + '#include "include_all_grouped_conv_bwd_weight_kernels.hpp"', + "#define CONV_BWD_WEIGHT_2D_AVAILABLE 1", + ] + if bwd_weight_3d: + lines += [ + "#define CONV_BWD_WEIGHT_3D_AVAILABLE 1", + f"using ConvBwdWeight3dLauncher = {bwd_weight_3d};", + ] + + # Kernel name table for Python introspection + names = [] + if True: # fwd 2D always present + names.append('"fwd_2d"') + if fwd_3d: + 
names.append('"fwd_3d"') + if True: # bwd_data 2D + names.append('"bwd_data_2d"') + if bwd_data_3d: + names.append('"bwd_data_3d"') + if True: # bwd_weight 2D + names.append('"bwd_weight_2d"') + if bwd_weight_3d: + names.append('"bwd_weight_3d"') + + lines += [ + "", + "// Kernel inventory for Python", + f"static const char* CONV_KERNEL_NAMES[] = {{{', '.join(names)}}};", + f"static const int CONV_KERNEL_COUNT = {len(names)};", + "", + ] + + Path(args.output).write_text("\n".join(lines) + "\n") + print(f"Generated dispatch header: {args.output} ({len(names)} kernels)") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py index 911ea61bd7..aef8f4ff0b 100755 --- a/dispatcher/scripts/parallel_kernel_builder.py +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -132,7 +132,7 @@ def main(): print(f"Linking failed: {result.stderr}") return 1 - print(f"✓ Built: {lib_path}") + print(f"OK Built: {lib_path}") return 0 diff --git a/dispatcher/scripts/stress_test_autocorrect.py b/dispatcher/scripts/stress_test_autocorrect.py index 13e92abffa..63b250071e 100644 --- a/dispatcher/scripts/stress_test_autocorrect.py +++ b/dispatcher/scripts/stress_test_autocorrect.py @@ -34,9 +34,9 @@ from compile_gemm_examples import ( # noqa: E402 validate_kernel_config, expand_declaration_with_arch_filter, ) -from compile_conv_examples import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, +from compile_grouped_conv_examples import ( # noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, ) @@ -316,7 +316,7 @@ def test_python_autocorrect(verbose=False): if was_modified: print(f" Modified: {len(corrections)} correction(s)") for c in corrections: - print(f" • {c}") + print(f" - {c}") except Exception as e: results["failed"] += 1 @@ -465,7 
+465,7 @@ def run_stress_test(arch, num_samples, verbose): } expanded = expand_declaration_with_arch_filter(config, test_arch) - status = "✓" if expanded else "✗" + status = "OK" if expanded else "FAIL" expected = test_arch in test["expected_archs"] match = "OK" if (bool(expanded) == expected) else "MISMATCH" diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp index fdb400921e..2cb589adf2 100644 --- a/dispatcher/src/dispatcher.cpp +++ b/dispatcher/src/dispatcher.cpp @@ -2,17 +2,18 @@ // SPDX-License-Identifier: MIT #include "ck_tile/dispatcher/dispatcher.hpp" -#include +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include #include namespace ck_tile { namespace dispatcher { -Dispatcher::Dispatcher(Registry* registry) +Dispatcher::Dispatcher(Registry* registry, const std::string& gfx_arch) : registry_(registry ? registry : &Registry::instance()), heuristic_(nullptr), - strategy_(SelectionStrategy::FirstFit) + strategy_(SelectionStrategy::FirstFit), + gfx_arch_(gfx_arch) { } @@ -61,7 +62,7 @@ float Dispatcher::run_fused(const void* a_ptr, std::ostringstream oss; oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw NoKernelFound(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); @@ -78,7 +79,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, auto kernel = registry_->lookup(kernel_id); if(!kernel) { - throw std::runtime_error("Kernel not found: " + kernel_id); + throw NoKernelFound("Kernel not found: " + kernel_id); } if(!kernel->supports(problem)) @@ -86,7 +87,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, std::ostringstream oss; oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw UnsupportedProblem(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, 
problem, stream); diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp index 0d83afd613..f565885181 100644 --- a/dispatcher/src/registry.cpp +++ b/dispatcher/src/registry.cpp @@ -5,39 +5,32 @@ #include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include +#include +#include namespace ck_tile { namespace dispatcher { -Registry::Registry() - : name_("default"), - auto_export_enabled_(false), - auto_export_include_statistics_(true), - auto_export_on_every_registration_(true) -{ -} +Registry::Registry() = default; Registry::~Registry() { - // Perform auto-export on destruction if enabled (regardless of export_on_every_registration - // setting) if(auto_export_enabled_) { perform_auto_export(); } } -Registry::Registry(Registry&& other) noexcept - : mutex_() // mutex is not movable, create new one - , - kernels_(std::move(other.kernels_)), - name_(std::move(other.name_)), - auto_export_enabled_(other.auto_export_enabled_), - auto_export_filename_(std::move(other.auto_export_filename_)), - auto_export_include_statistics_(other.auto_export_include_statistics_), - auto_export_on_every_registration_(other.auto_export_on_every_registration_) +Registry::Registry(Registry&& other) noexcept : Base(std::move(other)) { - // Disable auto-export on the moved-from object to prevent double export + // Base move constructor already locked+released other.mutex_. + // Re-acquire to safely read the remaining fields. 
+ std::lock_guard lock(other.mutex()); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + other.auto_export_enabled_ = false; } @@ -45,11 +38,7 @@ Registry& Registry::operator=(Registry&& other) noexcept { if(this != &other) { - std::lock_guard lock(mutex_); - std::lock_guard other_lock(other.mutex_); - - kernels_ = std::move(other.kernels_); - name_ = std::move(other.name_); + Base::operator=(std::move(other)); auto_export_enabled_ = other.auto_export_enabled_; auto_export_filename_ = std::move(other.auto_export_filename_); auto_export_include_statistics_ = other.auto_export_include_statistics_; @@ -64,55 +53,27 @@ Registry& Registry::operator=(Registry&& other) noexcept bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) { if(!instance) - { return false; - } - const std::string identifier = instance->get_key().encode_identifier(); - - bool registered = false; + if(Base::register_kernel(instance->get_name(), instance, priority)) { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + if(auto_export_enabled_ && auto_export_on_every_registration_) { - // Kernel with this identifier already exists - // Only replace if new priority is higher - if(priority > it->second.priority) - { - it->second.instance = instance; - it->second.priority = priority; - registered = true; - } - } - else - { - // New kernel, insert it - kernels_[identifier] = RegistryEntry{instance, priority}; - registered = true; + perform_auto_export(); } + return true; } - - // Perform auto-export if enabled and configured to export on every registration - if(registered && auto_export_enabled_ && auto_export_on_every_registration_) - { - perform_auto_export(); - } - - return registered; + return false; } 
KernelInstancePtr Registry::lookup(const std::string& identifier) const { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + std::lock_guard lock(mutex()); + auto it = entries().find(identifier); + if(it != entries().end()) { return it->second.instance; } - return nullptr; } @@ -121,75 +82,23 @@ KernelInstancePtr Registry::lookup(const KernelKey& key) const return lookup(key.encode_identifier()); } -std::vector Registry::get_all() const -{ - std::lock_guard lock(mutex_); - - std::vector result; - result.reserve(kernels_.size()); - - for(const auto& pair : kernels_) - { - result.push_back(pair.second.instance); - } - - return result; -} +std::vector Registry::get_all() const { return Base::get_all_instances(); } std::vector Registry::filter(std::function predicate) const { - std::lock_guard lock(mutex_); - + std::lock_guard lock(mutex()); std::vector result; - - for(const auto& pair : kernels_) + for(const auto& [name, entry] : entries()) { - if(predicate(*pair.second.instance)) + if(predicate(*(entry.instance))) { - result.push_back(pair.second.instance); + result.push_back(entry.instance); } } - return result; } -std::size_t Registry::size() const -{ - std::lock_guard lock(mutex_); - return kernels_.size(); -} - -bool Registry::empty() const -{ - std::lock_guard lock(mutex_); - return kernels_.empty(); -} - -void Registry::clear() -{ - std::lock_guard lock(mutex_); - kernels_.clear(); -} - -const std::string& Registry::get_name() const -{ - std::lock_guard lock(mutex_); - return name_; -} - -void Registry::set_name(const std::string& name) -{ - std::lock_guard lock(mutex_); - name_ = name; -} - -Registry& Registry::instance() -{ - static Registry global_registry; - return global_registry; -} - std::string Registry::export_json(bool include_statistics) const { return export_registry_json(*this, include_statistics); @@ -204,7 +113,7 @@ void Registry::enable_auto_export(const std::string& filename, bool 
include_statistics, bool export_on_every_registration) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = true; auto_export_filename_ = filename; auto_export_include_statistics_ = include_statistics; @@ -213,13 +122,13 @@ void Registry::enable_auto_export(const std::string& filename, void Registry::disable_auto_export() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = false; } bool Registry::is_auto_export_enabled() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); return auto_export_enabled_; } @@ -230,7 +139,7 @@ void Registry::perform_auto_export() bool include_stats; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); if(!auto_export_enabled_) { return; @@ -243,31 +152,15 @@ void Registry::perform_auto_export() export_json_to_file(filename, include_stats); } -std::size_t Registry::merge_from(const Registry& other, Priority priority) -{ - auto other_kernels = other.get_all(); - std::size_t merged_count = 0; - - for(const auto& kernel : other_kernels) - { - if(register_kernel(kernel, priority)) - { - merged_count++; - } - } - - return merged_count; -} - std::size_t Registry::filter_by_arch(const std::string& gpu_arch) { ArchFilter filter(gpu_arch); std::vector to_remove; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); - for(const auto& pair : kernels_) + for(const auto& pair : entries()) { if(!filter.is_valid(pair.second.instance->get_key())) { @@ -277,12 +170,18 @@ std::size_t Registry::filter_by_arch(const std::string& gpu_arch) for(const auto& key : to_remove) { - kernels_.erase(key); + entries_mut().erase(key); } } return to_remove.size(); } +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + } // namespace dispatcher -} // namespace ck_tile +} // namespace ck_tile \ No newline at end of file diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt index 
6c20c18c95..a54feba284 100644 --- a/dispatcher/tests/CMakeLists.txt +++ b/dispatcher/tests/CMakeLists.txt @@ -217,6 +217,10 @@ endforeach() # Standalone integration tests (with their own main()) set(STANDALONE_TESTS test_minimal.cpp + test_grouped_conv_config.cpp + test_grouped_conv_problem.cpp + test_grouped_conv_kernel_decl.cpp + test_grouped_conv_registry.cpp ) foreach(test_source ${STANDALONE_TESTS}) diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py index 0ec3ebda3c..3f52049f74 100644 --- a/dispatcher/tests/test_autocorrect.py +++ b/dispatcher/tests/test_autocorrect.py @@ -42,10 +42,10 @@ from compile_gemm_examples import ( # noqa: E402 expand_declaration_with_arch_filter, is_wildcard_declaration, ) -from compile_conv_examples import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, - is_conv_wildcard_declaration, +from compile_grouped_conv_examples import ( # noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, + is_grouped_conv_wildcard_declaration as is_conv_wildcard_declaration, ) from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 diff --git a/dispatcher/tests/test_codegen_common.py b/dispatcher/tests/test_codegen_common.py new file mode 100644 index 0000000000..2efeaefb4d --- /dev/null +++ b/dispatcher/tests/test_codegen_common.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for codegen/codegen_common.py -- shared infrastructure for GEMM and grouped conv codegen. + +Phase 1a TDD: these tests are written BEFORE the implementation exists. 
+Run: python3 -m pytest tests/test_codegen_common.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from codegen_common import ( # noqa: E402 + TileConfig, + TraitConfigBase, + CommonTypeMappings, + generate_cpp_compilation_unit, + parallel_generate, + valid_wave_configs, + valid_warp_configs, + valid_trait_configs, + needs_wave_expansion, + needs_warp_expansion, + needs_pipeline_expansion, +) + + +class TestTileConfig(unittest.TestCase): + """TileConfig dataclass tests.""" + + def test_valid_config(self): + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_zero_tile_invalid(self): + tc = TileConfig(0, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_non_divisible_invalid(self): + tc = TileConfig(127, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_all_fields_accessible(self): + tc = TileConfig(256, 128, 64, 4, 1, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 256) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 64) + self.assertEqual(tc.warp_m, 4) + self.assertEqual(tc.warp_n, 1) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_small_valid_config(self): + tc = TileConfig(16, 16, 16, 1, 1, 1, 16, 16, 16) + self.assertTrue(tc.is_valid()) + + +class TestTraitConfigBase(unittest.TestCase): + """TraitConfigBase dataclass tests.""" + + def test_valid_intrawave(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_invalid_interwave_compv3(self): + tc = TraitConfigBase("compv3", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_invalid_interwave_compv4(self): + tc = 
TraitConfigBase("compv4", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_valid_mem_interwave(self): + tc = TraitConfigBase("mem", "cshuffle", "interwave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_valid_mem_intrawave(self): + tc = TraitConfigBase("mem", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_padding_fields(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", True, True, True) + self.assertTrue(tc.pad_m) + self.assertTrue(tc.pad_n) + self.assertTrue(tc.pad_k) + + +class TestCommonTypeMappings(unittest.TestCase): + """CommonTypeMappings tests.""" + + def test_dtype_to_ck(self): + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp16"], "fp16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["bf16"], "bf16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp32"], "float") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp8"], "fp8_t") + + def test_pipeline_to_ck(self): + self.assertEqual( + CommonTypeMappings.PIPELINE_TO_CK["mem"], "GemmPipelineAgBgCrMem" + ) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_CK) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_CK) + + def test_pipeline_to_base(self): + self.assertIn("mem", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_BASE) + + def test_scheduler_to_ck(self): + self.assertIn("intrawave", CommonTypeMappings.SCHEDULER_TO_CK) + self.assertIn("interwave", CommonTypeMappings.SCHEDULER_TO_CK) + + def test_epilogue_to_dispatcher(self): + self.assertIn("cshuffle", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + self.assertIn("default", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + + def test_layout_to_ck(self): + self.assertIn("r", CommonTypeMappings.LAYOUT_TO_CK) + self.assertIn("c", CommonTypeMappings.LAYOUT_TO_CK) + + def test_get_output_dtype(self): + 
self.assertEqual(CommonTypeMappings.get_output_dtype("fp8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("bf8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp16"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp32"), "fp32") + + +class TestGenerateCppCompilationUnit(unittest.TestCase): + """Tests for generate_cpp_compilation_unit.""" + + def test_includes_kernel_header(self): + result = generate_cpp_compilation_unit("my_kernel") + self.assertIn('#include "my_kernel.hpp"', result) + + def test_contains_pragma_once_or_guard(self): + result = generate_cpp_compilation_unit("test_kernel") + self.assertIn("test_kernel", result) + + def test_different_names_different_output(self): + a = generate_cpp_compilation_unit("kernel_a") + b = generate_cpp_compilation_unit("kernel_b") + self.assertNotEqual(a, b) + + +class TestParallelGenerate(unittest.TestCase): + """Tests for parallel_generate helper.""" + + def _dummy_generate(self, item): + return f"generated_{item}" + + def test_parallel_returns_all(self): + items = ["a", "b", "c", "d"] + results = parallel_generate(self._dummy_generate, items, parallel=True) + self.assertEqual(len(results), 4) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_sequential_returns_all(self): + items = ["x", "y", "z"] + results = parallel_generate(self._dummy_generate, items, parallel=False) + self.assertEqual(len(results), 3) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_empty_items(self): + results = parallel_generate(self._dummy_generate, [], parallel=True) + self.assertEqual(len(results), 0) + + def test_logs_per_kernel_progress(self): + items = ["k1", "k2"] + with self.assertLogs(level="INFO") as cm: + parallel_generate(self._dummy_generate, items, parallel=False) + log_output = "\n".join(cm.output) + self.assertIn("k1", log_output) + self.assertIn("k2", log_output) + + +class TestArchAwareExpansion(unittest.TestCase): + 
"""Tests for arch-aware expansion helpers (best-of-conv).""" + + def test_valid_wave_configs_gfx942(self): + configs = valid_wave_configs("gfx942") + self.assertIsInstance(configs, list) + self.assertIn([2, 2, 1], configs) + self.assertIn([1, 4, 1], configs) + + def test_valid_wave_configs_unknown_arch(self): + configs = valid_wave_configs("gfx_unknown") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_warp_configs_gfx942_fp16(self): + configs = valid_warp_configs("gfx942", "fp16") + self.assertIsInstance(configs, list) + self.assertIn([32, 32, 16], configs) + + def test_valid_warp_configs_unknown_arch(self): + configs = valid_warp_configs("gfx_unknown", "fp16") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_trait_configs_excludes_interwave_compute(self): + configs = valid_trait_configs() + self.assertIsInstance(configs, list) + self.assertNotIn(("compv3", "cshuffle", "interwave"), configs) + self.assertNotIn(("compv4", "cshuffle", "interwave"), configs) + + def test_valid_trait_configs_includes_mem_interwave(self): + configs = valid_trait_configs() + has_mem_interwave = any(p == "mem" and s == "interwave" for p, s in configs) + self.assertTrue(has_mem_interwave) + + def test_needs_wave_expansion_wildcard(self): + self.assertTrue(needs_wave_expansion({"wave_m": -1, "wave_n": 2})) + self.assertTrue(needs_wave_expansion({"wave_m": 2, "wave_n": -1})) + + def test_needs_wave_expansion_explicit(self): + self.assertFalse(needs_wave_expansion({"wave_m": 2, "wave_n": 2})) + + def test_needs_warp_expansion_wildcard(self): + self.assertTrue(needs_warp_expansion({"warp_m": -1, "warp_n": 32})) + + def test_needs_warp_expansion_explicit(self): + self.assertFalse(needs_warp_expansion({"warp_m": 32, "warp_n": 32})) + + def test_needs_pipeline_expansion_wildcard(self): + self.assertTrue(needs_pipeline_expansion({"pipeline": "*"})) + + def test_needs_pipeline_expansion_explicit(self): + 
self.assertFalse(needs_pipeline_expansion({"pipeline": "compv4"})) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_dispatcher_common.py b/dispatcher/tests/test_dispatcher_common.py new file mode 100644 index 0000000000..2c0fc8307c --- /dev/null +++ b/dispatcher/tests/test_dispatcher_common.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for python/dispatcher_common.py -- shared Python dispatcher utilities. + +Phase 1b TDD: tests written BEFORE implementation exists. +Run: python3 -m pytest tests/test_dispatcher_common.py -v +""" + +import io +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + get_arch_filter_data, + ValidationResultBase, + validate_wave_config, + validate_warp_tile_config, + validate_trait_combo, + auto_correct_wave, + auto_correct_trait, + Colors, + print_phase, + print_success, + print_error, + print_info, + cleanup_generated_kernels, +) + + +class TestPathHelpers(unittest.TestCase): + """Tests for path helper functions.""" + + def test_dispatcher_root_contains_codegen(self): + root = get_dispatcher_root() + self.assertTrue((root / "codegen").exists()) + + def test_ck_root_contains_include_or_is_parent(self): + root = get_ck_root() + self.assertTrue(root.exists()) + self.assertEqual(root, get_dispatcher_root().parent) + + def test_build_dir_is_under_dispatcher(self): + build = get_build_dir() + self.assertEqual(build.parent, get_dispatcher_root()) + + def test_generated_kernels_dir_under_build(self): + gen_dir = get_generated_kernels_dir() + 
self.assertEqual(gen_dir.parent, get_build_dir()) + + +class TestGetArchFilterData(unittest.TestCase): + """Tests for get_arch_filter_data.""" + + def test_returns_dict(self): + data = get_arch_filter_data() + self.assertIsInstance(data, dict) + + def test_has_warp_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_combos", data) + + def test_has_warp_tile_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_tile_combos", data) + + def test_has_trait_unsupported(self): + data = get_arch_filter_data() + self.assertIn("trait_unsupported", data) + + def test_has_supported_archs(self): + data = get_arch_filter_data() + self.assertIn("supported_archs", data) + self.assertIn("gfx942", data["supported_archs"]) + + def test_gfx942_wave_configs(self): + data = get_arch_filter_data() + gfx942 = data["warp_combos"].get("gfx942", []) + self.assertIn([2, 2, 1], gfx942) + + +class TestValidationResultBase(unittest.TestCase): + """Tests for ValidationResultBase dataclass.""" + + def test_valid_result(self): + vr = ValidationResultBase(is_valid=True) + self.assertTrue(vr.is_valid) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + def test_invalid_result(self): + vr = ValidationResultBase( + is_valid=False, + errors=["bad wave"], + suggested_fixes={"wave_m": 2}, + ) + self.assertFalse(vr.is_valid) + self.assertEqual(len(vr.errors), 1) + self.assertIn("wave_m", vr.suggested_fixes) + + +class TestValidateWaveConfig(unittest.TestCase): + """Tests for validate_wave_config.""" + + def test_valid_wave(self): + is_valid, msg = validate_wave_config([2, 2, 1], "gfx942") + self.assertTrue(is_valid) + self.assertEqual(msg, "") + + def test_invalid_wave(self): + is_valid, msg = validate_wave_config([3, 3, 1], "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", msg.lower()) + + +class TestValidateWarpTileConfig(unittest.TestCase): + """Tests for validate_warp_tile_config.""" + + def 
test_valid_warp_tile(self): + is_valid, msg = validate_warp_tile_config([32, 32, 16], "gfx942", "fp16") + self.assertTrue(is_valid) + + def test_invalid_warp_tile(self): + is_valid, msg = validate_warp_tile_config([99, 99, 99], "gfx942", "fp16") + self.assertFalse(is_valid) + self.assertIn("warp", msg.lower()) + + +class TestValidateTraitCombo(unittest.TestCase): + """Tests for validate_trait_combo.""" + + def test_valid_trait(self): + is_valid, msg = validate_trait_combo("compv3", "cshuffle", "intrawave") + self.assertTrue(is_valid) + + def test_invalid_trait_interwave_compute(self): + is_valid, msg = validate_trait_combo("compv4", "cshuffle", "interwave") + self.assertFalse(is_valid) + + def test_valid_mem_interwave(self): + is_valid, msg = validate_trait_combo("mem", "cshuffle", "interwave") + self.assertTrue(is_valid) + + +class TestAutoCorrectWave(unittest.TestCase): + """Tests for auto_correct_wave.""" + + def test_corrects_invalid_wave(self): + corrected = auto_correct_wave([1, 1, 1], "gfx942") + self.assertIsInstance(corrected, list) + self.assertEqual(len(corrected), 3) + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get("gfx942", [[2, 2, 1]]) + self.assertIn(corrected, valid_waves) + + +class TestAutoCorrectTrait(unittest.TestCase): + """Tests for auto_correct_trait.""" + + def test_corrects_invalid_scheduler(self): + corrected_pipeline, corrected_scheduler = auto_correct_trait( + "compv4", "interwave" + ) + self.assertEqual(corrected_scheduler, "intrawave") + + +class TestColors(unittest.TestCase): + """Tests for Colors class (cross-platform ANSI support from conv).""" + + def test_green_returns_string(self): + result = Colors.green("ok") + self.assertIn("ok", result) + + def test_red_returns_string(self): + result = Colors.red("error") + self.assertIn("error", result) + + def test_yellow_returns_string(self): + result = Colors.yellow("warn") + self.assertIn("warn", result) + + def test_bold_returns_string(self): + result = 
Colors.bold("title") + self.assertIn("title", result) + + def test_plain_mode_no_ansi(self): + with patch.object(Colors, "_use_color", return_value=False): + result = Colors.green("plain") + self.assertEqual(result, "plain") + + +class TestPhasedOutput(unittest.TestCase): + """Tests for phased output helpers.""" + + def test_print_phase(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_phase(1, "Setup") + self.assertIn("Setup", buf.getvalue()) + + def test_print_success(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_success("Done") + self.assertIn("Done", buf.getvalue()) + + def test_print_error(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_error("Oops") + self.assertIn("Oops", buf.getvalue()) + + def test_print_info(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_info("FYI") + self.assertIn("FYI", buf.getvalue()) + + +class TestCleanup(unittest.TestCase): + """Tests for cleanup_generated_kernels.""" + + def test_cleanup_nonexistent_dir_no_error(self): + cleanup_generated_kernels(Path("/tmp/nonexistent_ck_test_dir_12345")) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py index cfd18a3305..d02ea69787 100644 --- a/dispatcher/tests/test_examples_integration.py +++ b/dispatcher/tests/test_examples_integration.py @@ -28,14 +28,18 @@ sys.path.insert(0, str(PYTHON_DIR)) def run_python_example( - example_path: Path, timeout: int = 120 + example_path: Path, timeout: int = 120, extra_args: list = None ) -> subprocess.CompletedProcess: """Run a Python example and capture output.""" env = os.environ.copy() env["PYTHONPATH"] = str(PYTHON_DIR) + cmd = [sys.executable, str(example_path)] + if extra_args: + cmd.extend(extra_args) + return subprocess.run( - [sys.executable, str(example_path)], + cmd, capture_output=True, text=True, timeout=timeout, @@ -111,61 +115,74 @@ class 
TestGemmPythonExamples(unittest.TestCase): result = run_python_example(example) self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - # Should pass validation self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvPythonExamples(unittest.TestCase): - """Test Conv Python examples.""" + """Test grouped conv Python examples.""" @classmethod def setUpClass(cls): """Check if examples directory exists.""" - cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + cls.conv_examples_dir = EXAMPLES_DIR / "grouped_conv" / "python" if not cls.conv_examples_dir.exists(): - raise unittest.SkipTest("Conv Python examples not found") + raise unittest.SkipTest("Grouped conv Python examples not found") - def test_01_basic_conv(self): - """Test basic conv example.""" - example = self.conv_examples_dir / "01_basic_conv.py" + def test_01_basic_grouped_conv(self): + """Test basic grouped conv example.""" + example = self.conv_examples_dir / "01_basic_grouped_conv.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_02_conv2d_fwd(self): - """Test 2D forward conv example.""" - example = self.conv_examples_dir / "02_conv2d_fwd.py" + def test_02_forward(self): + """Test forward conv example (2D + 3D).""" + example = self.conv_examples_dir / "02_forward.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_03_conv3d_fwd(self): - """Test 3D forward conv example.""" - example = self.conv_examples_dir / "03_conv3d_fwd.py" + def test_03_bwd_data(self): + """Test backward data example.""" + example 
= self.conv_examples_dir / "03_bwd_data.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_07_validation(self): - """Test validation example.""" - example = self.conv_examples_dir / "07_validation.py" + def test_04_bwd_weight(self): + """Test backward weight example.""" + example = self.conv_examples_dir / "04_bwd_weight.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_05_benchmark(self): + """Test benchmark example.""" + example = self.conv_examples_dir / "05_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example( + example, extra_args=["--warmup", "1", "--repeat", "1"] + ) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) + + def test_06_registry_json(self): + """Test registry + heuristic + JSON example.""" + example = self.conv_examples_dir / "06_registry_json.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example(example) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestGemmCppExamples(unittest.TestCase): @@ -195,18 +212,18 @@ class TestGemmCppExamples(unittest.TestCase): self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - def test_gemm_04_validation(self): - """Test validation GEMM C++ example.""" - result = run_cpp_example("gemm_04_validation") + def test_gemm_03_benchmark_validation(self): + """Test 
benchmark+validation GEMM C++ example.""" + result = run_cpp_example("gemm_03_benchmark_validation") if result is None: - self.skipTest("gemm_04_validation not built") + self.skipTest("gemm_03_benchmark_validation not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvCppExamples(unittest.TestCase): - """Test Conv C++ examples.""" + """Test grouped conv C++ examples.""" @classmethod def setUpClass(cls): @@ -215,23 +232,29 @@ class TestConvCppExamples(unittest.TestCase): if not cls.examples_dir.exists(): raise unittest.SkipTest("C++ examples not built") - def test_conv_01_forward(self): - """Test forward conv C++ example.""" - result = run_cpp_example("conv_01_forward") + def test_grouped_conv_01_basic(self): + """Test basic grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_01_basic") if result is None: - self.skipTest("conv_01_forward not built") - + self.skipTest("grouped_conv_01_basic not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_conv_02_validation(self): - """Test validation conv C++ example.""" - result = run_cpp_example("conv_02_validation") + def test_grouped_conv_02_all_dirs(self): + """Test all directions grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_02_all_dirs") if result is None: - self.skipTest("conv_02_validation not built") - + self.skipTest("grouped_conv_02_all_dirs not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_grouped_conv_03_bench_val(self): + """Test benchmark+validation grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_03_bench_val") 
+ if result is None: + self.skipTest("grouped_conv_03_bench_val not built") + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestUtilityImports(unittest.TestCase): @@ -246,14 +269,18 @@ class TestUtilityImports(unittest.TestCase): except ImportError as e: self.fail(f"Failed to import ctypes_utils: {e}") - def test_import_conv_utils(self): - """Test importing conv_utils.""" + def test_import_grouped_conv_utils(self): + """Test importing grouped_conv_utils.""" try: - from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + from grouped_conv_utils import ( # noqa: F401 + GroupedConvValidationResult, + validate_grouped_conv_config, + GroupedConvDataType, + ) self.assertTrue(True) except ImportError as e: - self.fail(f"Failed to import conv_utils: {e}") + self.fail(f"Failed to import grouped_conv_utils: {e}") def test_kernel_config_creation(self): """Test creating a KernelConfig.""" @@ -272,22 +299,19 @@ class TestUtilityImports(unittest.TestCase): self.assertEqual(config.dtype_a, "fp16") self.assertEqual(config.layout_a, "row") - def test_conv_signature_creation(self): - """Test creating a ConvSignature.""" - from conv_utils import ConvSignature + def test_grouped_conv_default_config(self): + """Test creating a grouped conv default config.""" + from grouped_conv_utils import get_grouped_conv_default_config - sig = ConvSignature( - dtype_in="fp16", - dtype_wei="fp16", - dtype_out="fp16", - dtype_acc="fp32", - layout="nhwgc", - direction="forward", - num_dims=2, + config = get_grouped_conv_default_config( + variant="forward", + ndim_spatial=2, + arch="gfx942", ) - self.assertEqual(sig.dtype_in, "fp16") - self.assertEqual(sig.direction, "forward") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertEqual(d["variant"], "forward") + self.assertEqual(d["arch"], "gfx942") class TestAutoCorrection(unittest.TestCase): @@ -316,21 +340,22 @@ class 
TestAutoCorrection(unittest.TestCase): self.assertTrue(was_modified, "Config should be modified") self.assertGreater(len(corrections), 0, "Should have corrections") - def test_conv_auto_correct(self): - """Test Conv auto-correction.""" - from conv_utils import auto_correct_conv_config - - # Call with invalid wave config parameters - corrected, was_modified, corrections = auto_correct_conv_config( - wave_m=99, # Invalid - wave_n=99, # Invalid - wave_k=99, # Invalid - dtype="fp16", - arch="gfx942", + def test_grouped_conv_auto_correct(self): + """Test Grouped Conv auto-correction.""" + from grouped_conv_utils import ( + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, ) - self.assertTrue(was_modified, "Config should be modified") - self.assertGreater(len(corrections), 0, "Should have corrections") + config = get_grouped_conv_default_config() + d = config.to_dict() if hasattr(config, "to_dict") else config + d["tile_config"]["warp_m"] = [99] + d["tile_config"]["warp_n"] = [99] + + corrected, result = auto_correct_grouped_conv_config(d) + + self.assertIsInstance(corrected, dict) + self.assertIn("tile_config", corrected) if __name__ == "__main__": diff --git a/dispatcher/tests/test_grouped_conv_codegen.py b/dispatcher/tests/test_grouped_conv_codegen.py new file mode 100644 index 0000000000..acfa5abd8f --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_codegen.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for codegen/unified_grouped_conv_codegen.py -- grouped convolution code generator. + +These tests are written BEFORE the implementation exists. 
+Run: python3 -m pytest dispatcher/tests/test_grouped_conv_codegen.py -v +""" + +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +from codegen_common import TileConfig, TraitConfigBase # noqa: E402 + +from unified_grouped_conv_codegen import ( # noqa: E402 + GroupedConvVariant, + GroupedConvLayout, + GroupedConvKernelConfig, + GroupedConvTypeMappings, + GroupedConvTraitConfig, + CKTileGroupedConvKernelGenerator, + GroupedConvDispatcherWrapperGenerator, + UnifiedGroupedConvCodegen, +) + + +# ============================================================================= +# TestGroupedConvVariant +# ============================================================================= + + +class TestGroupedConvVariant(unittest.TestCase): + """Test GroupedConvVariant enum values.""" + + def test_forward_value(self): + self.assertEqual(GroupedConvVariant.FORWARD.value, "forward") + + def test_backward_data_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_DATA.value, "bwd_data") + + def test_backward_weight_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_WEIGHT.value, "bwd_weight") + + def test_all_variants_exist(self): + self.assertIn(GroupedConvVariant.FORWARD, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_DATA, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_WEIGHT, GroupedConvVariant) + + +# ============================================================================= +# TestGroupedConvLayout +# ============================================================================= + + +class TestGroupedConvLayout(unittest.TestCase): + """Test GroupedConvLayout enum for 1D/2D/3D layouts.""" + + def test_nhwgc_value(self): + self.assertEqual(GroupedConvLayout.NHWGC.value, "NHWGC") + + def 
test_gkyxc_value(self): + self.assertEqual(GroupedConvLayout.GKYXC.value, "GKYXC") + + def test_nhwgk_value(self): + self.assertEqual(GroupedConvLayout.NHWGK.value, "NHWGK") + + def test_1d_layouts_exist(self): + """1D conv layouts (e.g., NWGC, GYXC, NWGK).""" + layouts_1d = [ + lay + for lay in GroupedConvLayout + if "W" in lay.value and "H" not in lay.value + ] + self.assertGreater(len(layouts_1d), 0) + + def test_2d_layouts_exist(self): + """2D conv layouts (e.g., NHWGC, GKYXC, NHWGK).""" + layouts_2d = [lay for lay in GroupedConvLayout if "HW" in lay.value] + self.assertGreater(len(layouts_2d), 0) + + def test_3d_layouts_exist(self): + """3D conv layouts (e.g., NDHWGC, GDKYXC).""" + layouts_3d = [ + lay for lay in GroupedConvLayout if "D" in lay.value or "DHW" in lay.value + ] + self.assertGreater(len(layouts_3d), 0) + + +# ============================================================================= +# TestGroupedConvKernelConfig +# ============================================================================= + + +class TestGroupedConvKernelConfig(unittest.TestCase): + """Test GroupedConvKernelConfig dataclass.""" + + def _make_tile(self): + return TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + def _make_trait(self): + return GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + + def test_name_contains_grouped_conv_fwd(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("grouped_conv_fwd", name) + + def test_name_backward_data_contains_bwd_data(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.BACKWARD_DATA, + ndim_spatial=2, + arch="gfx942", + 
layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("bwd_data", name) + + def test_is_valid_for_arch_supported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertTrue(config.is_valid_for_arch("gfx942")) + + def test_is_valid_for_arch_unsupported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertFalse(config.is_valid_for_arch("gfx600")) + + +# ============================================================================= +# TestGroupedConvTypeMappings +# ============================================================================= + + +class TestGroupedConvTypeMappings(unittest.TestCase): + """Test GroupedConvTypeMappings class.""" + + def test_dtype_to_ck_fp16(self): + self.assertEqual(GroupedConvTypeMappings.DTYPE_TO_CK["fp16"], "half_t") + + def test_dtype_to_ck_bf16(self): + self.assertIn("bf16", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_dtype_to_ck_fp32(self): + self.assertIn("fp32", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_get_layouts_2d_has_in_wei_out_keys(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_2d_returns_dict(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIsInstance(layouts, dict) + + def test_get_layouts_1d(self): + layouts = GroupedConvTypeMappings.get_layouts(1) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_3d(self): + layouts = GroupedConvTypeMappings.get_layouts(3) 
+ self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + +# ============================================================================= +# TestCKTileGroupedConvKernelGenerator +# ============================================================================= + + +class TestCKTileGroupedConvKernelGenerator(unittest.TestCase): + """Test CKTileGroupedConvKernelGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_pragma_once(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#pragma once", result) + + def test_generate_contains_forward_kernel_include(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("grouped_convolution_forward_kernel.hpp", result) + + def test_generate_returns_non_empty_string(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 100) + + def test_generate_valid_cpp_structure(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#include", result) + self.assertIn("ck_tile", result) + + +# ============================================================================= +# TestGroupedConvDispatcherWrapperGenerator +# ============================================================================= + + 
+class TestGroupedConvDispatcherWrapperGenerator(unittest.TestCase): + """Test GroupedConvDispatcherWrapperGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_dispatcher_registration(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("dispatcher", result) + self.assertIn("KernelKey", result) + self.assertIn("KernelInstancePtr", result) + + def test_generate_contains_pragma_once(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#pragma once", result) + + def test_generate_valid_cpp(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#include", result) + self.assertIn("namespace", result) + + +# ============================================================================= +# TestUnifiedGroupedConvCodegen +# ============================================================================= + + +class 
TestUnifiedGroupedConvCodegen(unittest.TestCase): + """Test UnifiedGroupedConvCodegen.generate_all().""" + + def test_generate_all_returns_dict_with_expected_keys(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + with patch.object( + codegen, + "_get_configs", + return_value=[], # Mock empty config list for fast test + ): + results = codegen.generate_all(parallel=False) + self.assertIn("kernels", results) + self.assertIn("wrappers", results) + self.assertIn("failed", results) + self.assertIsInstance(results["kernels"], list) + self.assertIsInstance(results["wrappers"], list) + self.assertIsInstance(results["failed"], list) + + def test_generate_all_with_mock_config_produces_output(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv_test" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + # Use a real config - patch the config source to return one config + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + with patch.object(codegen, "_get_configs", return_value=[config]): + results = codegen.generate_all(parallel=False) + self.assertIsInstance(results, dict) + self.assertIn("kernels", results) + + +# ============================================================================= +# TestSharedImports +# 
============================================================================= + + +class TestSharedImports(unittest.TestCase): + """Verify TileConfig from codegen_common and GroupedConvTraitConfig extends TraitConfigBase.""" + + def test_tile_config_has_expected_fields(self): + """TileConfig from codegen_common has tile_m, tile_n, tile_k, etc.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 128) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 32) + self.assertEqual(tc.warp_m, 2) + self.assertEqual(tc.warp_n, 2) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_tile_config_is_from_codegen_common(self): + """TileConfig used by grouped conv is the same as codegen_common.TileConfig.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_grouped_conv_trait_config_extends_trait_config_base(self): + """GroupedConvTraitConfig extends TraitConfigBase.""" + self.assertTrue(issubclass(GroupedConvTraitConfig, TraitConfigBase)) + + def test_grouped_conv_trait_config_has_double_smem_buffer(self): + """GroupedConvTraitConfig has double_smem_buffer field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=True, + num_groups_to_merge=2, + ) + self.assertTrue(trait.double_smem_buffer) + self.assertEqual(trait.num_groups_to_merge, 2) + + def test_grouped_conv_trait_config_has_num_groups_to_merge(self): + """GroupedConvTraitConfig has num_groups_to_merge field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=4, + ) + self.assertEqual(trait.num_groups_to_merge, 4) + + def test_grouped_conv_trait_config_inherits_base_fields(self): + """GroupedConvTraitConfig inherits pipeline, epilogue, scheduler from 
base.""" + trait = GroupedConvTraitConfig( + "compv4", + "cshuffle", + "intrawave", + True, + True, + True, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + self.assertEqual(trait.pipeline, "compv4") + self.assertEqual(trait.epilogue, "cshuffle") + self.assertEqual(trait.scheduler, "intrawave") + self.assertTrue(trait.pad_m) + self.assertTrue(trait.pad_n) + self.assertTrue(trait.pad_k) + + +# ============================================================================= +# TestTwoStageBwdWeightCodegen +# ============================================================================= + + +def _make_two_stage_config(): + """Helper: create a two-stage bwd_weight config.""" + return GroupedConvKernelConfig( + tile=TileConfig(16, 64, 64, 1, 4, 1, 16, 16, 32), + trait=GroupedConvTraitConfig( + pipeline="compv3", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + two_stage=True, + ), + variant=GroupedConvVariant.BACKWARD_WEIGHT, + ndim_spatial=2, + arch="gfx942", + ) + + +class TestTwoStageBwdWeightCodegen(unittest.TestCase): + """Tests for two-stage backward weight kernel generation.""" + + def test_kernel_name_contains_2stage(self): + config = _make_two_stage_config() + name = config.name("fp16") + self.assertIn("_2stage", name) + self.assertIn("bwd_weight", name) + + def test_single_stage_name_has_no_2stage(self): + config = _make_two_stage_config() + config.trait.two_stage = False + name = config.name("fp16") + self.assertNotIn("_2stage", name) + + def test_generate_contains_elementwise_include(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("elementwise.hpp", code) + + def test_generate_contains_workspace_type(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + 
self.assertIn("WorkspaceDataType", code) + + def test_generate_contains_elementwise_kernel(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("ElementWiseKernel", code) + + def test_generate_contains_launch_kernel_time_mask(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("launch_kernel_time_mask", code) + + def test_generate_forces_vector_size_c_to_1(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("VectorSizeC_TwoStage = 1", code) + + def test_generate_contains_workspace_memset(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("hipMemsetAsync", code) + + def test_single_stage_does_not_contain_workspace(self): + config = _make_two_stage_config() + config.trait.two_stage = False + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertNotIn("WorkspaceDataType", code) + self.assertNotIn("ElementWiseKernel", code) + self.assertNotIn("launch_kernel_time_mask", code) + + def test_default_configs_include_two_stage(self): + from unified_grouped_conv_codegen import get_default_configs + + configs = get_default_configs( + arch="gfx942", + variants=[GroupedConvVariant.BACKWARD_WEIGHT], + ndims=[2], + ) + two_stage = [c for c in configs if c.trait.two_stage] + single_stage = [c for c in configs if not c.trait.two_stage] + self.assertGreater(len(two_stage), 0, "Should have two-stage configs") + self.assertGreater( + len(single_stage), 0, "Should still have single-stage configs" + ) + + +if 
__name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_grouped_conv_config.cpp b/dispatcher/tests/test_grouped_conv_config.cpp new file mode 100644 index 0000000000..c9a1faeaf9 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_config.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvConfig using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include <cassert> +#include <iostream> + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_direction_enum() +{ + std::cout << " test_grouped_conv_direction_enum... "; + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::FORWARD) == + std::string("fwd")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_DATA) == + std::string("bwd_data")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_WEIGHT) == + std::string("bwd_weight")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_signature_info() +{ + std::cout << " test_grouped_conv_signature_info... "; + GroupedConvSignatureInfo sig; + assert(sig.spatial_dim == 2); + assert(sig.direction == GroupedConvDirection::FORWARD); + assert(sig.in_type == "fp16"); + assert(sig.wei_type == "fp16"); + assert(sig.out_type == "fp16"); + assert(sig.acc_type == "fp32"); + assert(sig.num_groups == 1); + sig.in_type = "bf16"; + sig.num_groups = 4; + assert(sig.in_type == "bf16"); + assert(sig.num_groups == 4); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_info() +{ + std::cout << " test_grouped_conv_algorithm_info... 
"; + GroupedConvAlgorithmInfo algo; + assert(algo.tile.m == 128); + assert(algo.tile.n == 128); + assert(algo.tile.k == 64); + assert(algo.pipeline == PipelineVersion::V4); + assert(algo.scheduler == PipelineScheduler::INTRAWAVE); + assert(GroupedConvAlgorithmInfo::pipeline_str(PipelineVersion::V4) == std::string("compv4")); + assert(GroupedConvAlgorithmInfo::scheduler_str(PipelineScheduler::INTRAWAVE) == + std::string("intrawave")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_config() +{ + std::cout << " test_grouped_conv_config... "; + GroupedConvConfig cfg; + std::string name = cfg.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("2d") != std::string::npos); + + std::string brief = cfg.brief(); + assert(!brief.empty()); + assert(brief.find("2D") != std::string::npos || brief.find("Grouped") != std::string::npos); + + std::string detailed = cfg.detailed(); + assert(!detailed.empty()); + assert(detailed.find("Signature:") != std::string::npos); + assert(detailed.find("Algorithm:") != std::string::npos); + assert(detailed.find("Arch:") != std::string::npos); + std::cout << "PASSED\n"; +} + +void test_predefined_grouped_conv_configs() +{ + std::cout << " test_predefined_grouped_conv_configs... 
"; + configs::Memory mem_cfg; + assert(mem_cfg.algorithm.pipeline == PipelineVersion::MEMORY); + assert(mem_cfg.algorithm.tile.m == 128); + assert(mem_cfg.algorithm.tile.n == 32); + + configs::CompV3_Small compv3_small; + assert(compv3_small.algorithm.pipeline == PipelineVersion::V3); + assert(compv3_small.algorithm.tile.m == 16); + assert(compv3_small.algorithm.tile.n == 64); + + configs::CompV4 compv4; + assert(compv4.algorithm.pipeline == PipelineVersion::V4); + assert(compv4.algorithm.double_smem_buffer == true); + + configs::WMMA wmma_cfg; + assert(wmma_cfg.arch.name == "gfx1100"); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Config ===\n\n"; + test_grouped_conv_direction_enum(); + test_grouped_conv_signature_info(); + test_grouped_conv_algorithm_info(); + test_grouped_conv_config(); + test_predefined_grouped_conv_configs(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_kernel_decl.cpp b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp new file mode 100644 index 0000000000..7b28a451bc --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvKernelDecl using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include <cassert> +#include <iostream> + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_signature_builder() +{ + std::cout << " test_grouped_conv_signature_builder... 
"; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2).groups(4); + assert(sig.dtype_in_ == "fp16"); + assert(sig.dtype_wei_ == "fp16"); + assert(sig.dtype_out_ == "fp16"); + assert(sig.layout_ == "nhwc"); + assert(sig.conv_op_ == "forward"); + assert(sig.num_dims_ == 2); + assert(sig.groups_ == 4); + assert(sig.op_str() == "fwd"); + sig.conv_type("bwd_data"); + assert(sig.op_str() == "bwd_data"); + sig.conv_type("bwd_weight"); + assert(sig.op_str() == "bwd_weight"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_builder() +{ + std::cout << " test_grouped_conv_algorithm_builder... "; + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"); + assert(algo.tile_m_ == 128); + assert(algo.tile_n_ == 128); + assert(algo.tile_k_ == 64); + assert(algo.wave_m_ == 2); + assert(algo.wave_n_ == 2); + assert(algo.warp_m_ == 32); + assert(algo.warp_n_ == 32); + assert(algo.pipeline_ == "compv4"); + assert(algo.scheduler_ == "intrawave"); + assert(!algo.needs_expansion()); + algo.wave_m_ = ANY_INT; + assert(algo.needs_wave_expansion()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_decl() +{ + std::cout << " test_grouped_conv_kernel_decl... "; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2); + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64).wave(2, 2, 1).warp(32, 32, 16); + GroupedConvKernelDecl decl(sig, algo, "gfx942"); + std::string name = decl.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("128x128x64") != std::string::npos); + assert(!decl.has_wildcards()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set() +{ + std::cout << " test_grouped_conv_kernel_set... 
"; + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + assert(set.size() == 1); + set.add("fp16", "nhwc", "forward", 256, 256); + assert(set.size() == 2); + const auto& decls = set.declarations(); + assert(decls[0].algorithm.tile_n_ == 128); + assert(decls[0].algorithm.tile_k_ == 128); + assert(decls[1].algorithm.tile_n_ == 256); + assert(decls[1].algorithm.tile_k_ == 256); + set.tag("test_set"); + assert(set.tag() == "test_set"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_merge() +{ + std::cout << " test_grouped_conv_kernel_set_merge... "; + GroupedConvKernelSet set1; + set1.add("fp16", "nhwc", "forward", 128, 128); + GroupedConvKernelSet set2; + set2.add("fp16", "nhwc", "forward", 256, 256); + set1.merge(set2); + assert(set1.size() == 2); + assert(set1.declarations()[0].algorithm.tile_n_ == 128); + assert(set1.declarations()[1].algorithm.tile_n_ == 256); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_registry() +{ + std::cout << " test_grouped_conv_kernel_set_registry... "; + auto& reg = GroupedConvKernelSetRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set("gconv_test", set); + assert(reg.has("gconv_test")); + assert(reg.size() >= 1); + + const auto& retrieved = reg.get("gconv_test"); + assert(retrieved.size() == 1); + + const auto& empty = reg.get("nonexistent"); + assert(empty.size() == 0); + + reg.clear(); + assert(!reg.has("gconv_test")); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Kernel Decl ===\n\n"; + test_grouped_conv_signature_builder(); + test_grouped_conv_algorithm_builder(); + test_grouped_conv_kernel_decl(); + test_grouped_conv_kernel_set(); + test_grouped_conv_kernel_set_merge(); + test_grouped_conv_kernel_set_registry(); + std::cout << "\n=== All Tests Passed! 
===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_problem.cpp b/dispatcher/tests/test_grouped_conv_problem.cpp new file mode 100644 index 0000000000..a6a4d8ba08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_problem.cpp @@ -0,0 +1,245 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvProblem using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include <cassert> +#include <iostream> +#include <stdexcept> + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_problem_defaults() +{ + std::cout << " test_grouped_conv_problem_defaults... "; + GroupedConvProblem p; + assert(p.N == 1); + assert(p.C == 64); + assert(p.K == 64); + assert(p.G == 1); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.op == GroupedConvOp::Forward); + assert(p.stride[0] == 1 && p.stride[1] == 1 && p.stride[2] == 1); + assert(p.padding[0] == 0 && p.padding[1] == 0 && p.padding[2] == 0); + assert(p.dilation[0] == 1 && p.dilation[1] == 1 && p.dilation[2] == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_2d() +{ + std::cout << " test_grouped_conv_problem_2d... "; + GroupedConvProblem p(4, 64, 128, 28, 28, 3, 3); + p.compute_output_size(); + assert(p.N == 4); + assert(p.C == 64); + assert(p.K == 128); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.Ho() == 26); + assert(p.Wo() == 26); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_strided() +{ + std::cout << " test_grouped_conv_problem_strided... 
"; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 2, 2}; + p.padding = {0, 1, 1}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.Ho() == 7); + assert(p.Wo() == 7); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_grouped() +{ + std::cout << " test_grouped_conv_problem_grouped... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 4; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.G == 4); + assert(p.C % p.G == 0); + assert(p.K % p.G == 0); + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_depthwise() +{ + std::cout << " test_grouped_conv_problem_depthwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 64; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_depthwise()); + assert(p.G == p.C && p.G == p.K); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_pointwise() +{ + std::cout << " test_grouped_conv_problem_pointwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 128; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 1, 1}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_pointwise()); + assert(p.Y() == 1 && p.X() == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_flops() +{ + std::cout << " test_grouped_conv_problem_flops... 
"; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + double flops = p.get_flops(); + assert(flops > 0); + assert(flops == 2.0 * p.N * p.K * p.Ho() * p.Wo() * (p.C / p.G) * p.Y() * p.X()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_is_valid() +{ + std::cout << " test_grouped_conv_problem_is_valid... "; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.compute_output_size(); + assert(p.is_valid()); + + p.N = 0; + assert(!p.is_valid()); + p.N = 1; + + p.C = 0; + assert(!p.is_valid()); + p.C = 64; + + p.K = 0; + assert(!p.is_valid()); + p.K = 64; + + p.G = 0; + assert(!p.is_valid()); + p.G = 1; + + p.C = 64; + p.K = 64; + p.G = 3; + assert(!p.is_valid()); + p.G = 4; + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_builder() +{ + std::cout << " test_grouped_conv_problem_builder... 
"; + auto p = GroupedConvProblemBuilder() + .batch(8) + .channels(128, 256) + .groups(4) + .input_size(32, 32) + .filter_size(3, 3) + .stride(2, 2) + .padding(1, 1) + .dilation(1, 1) + .operation(GroupedConvOp::Forward) + .build(); + assert(p.N == 8); + assert(p.C == 128); + assert(p.K == 256); + assert(p.G == 4); + assert(p.Hi() == 32); + assert(p.Wi() == 32); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.stride[1] == 2 && p.stride[2] == 2); + assert(p.padding[1] == 1 && p.padding[2] == 1); + assert(p.op == GroupedConvOp::Forward); + assert(p.is_valid()); + + bool threw = false; + try + { + (void)GroupedConvProblemBuilder() + .batch(0) + .channels(64, 64) + .groups(1) + .input_size(14, 14) + .filter_size(3, 3) + .build(); + } + catch(const std::invalid_argument&) + { + threw = true; + } + assert(threw); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Problem ===\n\n"; + test_grouped_conv_problem_defaults(); + test_grouped_conv_problem_2d(); + test_grouped_conv_problem_strided(); + test_grouped_conv_problem_grouped(); + test_grouped_conv_problem_depthwise(); + test_grouped_conv_problem_pointwise(); + test_grouped_conv_problem_flops(); + test_grouped_conv_problem_is_valid(); + test_grouped_conv_problem_builder(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_registry.cpp b/dispatcher/tests/test_grouped_conv_registry.cpp new file mode 100644 index 0000000000..47d13a9997 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_registry.cpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvRegistry and GroupedConvDispatcher using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include <atomic> +#include <cassert> +#include <iostream> +#include <thread> + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_registry_basic() +{ + std::cout << " test_grouped_conv_registry_basic... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + reg.set_name("test_registry"); + assert(reg.name() == "test_registry"); + + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_register_set() +{ + std::cout << " test_grouped_conv_registry_register_set... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + + bool ok = reg.register_set(set); + assert(ok); + assert(reg.size() == 2); + assert(!reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_all_kernels() +{ + std::cout << " test_grouped_conv_registry_all_kernels... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto all = reg.all_kernels(); + assert(all.size() == 1); + assert(all[0]->name().find("grouped_conv_") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_clear() +{ + std::cout << " test_grouped_conv_registry_clear... 
"; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + assert(reg.size() == 1); + + reg.clear(); + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_thread_safe() +{ + std::cout << " test_grouped_conv_registry_thread_safe... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + const int num_threads = 4; + const int sets_per_thread = 10; + std::vector<std::thread> threads; + std::atomic<int> success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, &reg, &success_count]() { + for(int k = 0; k < sets_per_thread; k++) + { + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128 + t * 32 + k, 128); + if(reg.register_set(set)) + { + success_count++; + } + } + }); + } + + for(auto& th : threads) + th.join(); + + assert(reg.size() == num_threads * sets_per_thread); + assert(success_count.load() == num_threads * sets_per_thread); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_export_json() +{ + std::cout << " test_grouped_conv_registry_export_json... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + std::string json = reg.export_json(false); + assert(!json.empty()); + assert(json.find("\"kernels\"") != std::string::npos); + assert(json.find("\"metadata\"") != std::string::npos); + assert(json.find("grouped_conv_") != std::string::npos); + + std::string json_stats = reg.export_json(true); + assert(json_stats.find("\"statistics\"") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_filter() +{ + std::cout << " test_grouped_conv_registry_filter... 
"; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + set.add("bf16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto fp16_only = + reg.filter([](const GroupedConvKernelInstance& k) { return k.key().dtype_in == "fp16"; }); + assert(fp16_only.size() == 2); + + auto large_tile = reg.filter([](const GroupedConvKernelInstance& k) { + return k.key().tile_m >= 256 || k.key().tile_n >= 256; + }); + assert(large_tile.size() >= 1); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_basic() +{ + std::cout << " test_grouped_conv_dispatcher_basic... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(&reg); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + float time = dispatcher.run(problem, nullptr); + assert(time >= 0.0f); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_select() +{ + std::cout << " test_grouped_conv_dispatcher_select... 
"; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(&reg); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + const auto* selected = dispatcher.select(problem); + assert(selected != nullptr); + assert(selected->name().find("grouped_conv_") != std::string::npos); + assert(selected->matches(problem)); + + reg.clear(); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Registry ===\n\n"; + test_grouped_conv_registry_basic(); + test_grouped_conv_registry_register_set(); + test_grouped_conv_registry_all_kernels(); + test_grouped_conv_registry_clear(); + test_grouped_conv_registry_thread_safe(); + test_grouped_conv_registry_export_json(); + test_grouped_conv_registry_filter(); + test_grouped_conv_dispatcher_basic(); + test_grouped_conv_dispatcher_select(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_utils.py b/dispatcher/tests/test_grouped_conv_utils.py new file mode 100644 index 0000000000..9d0638dc08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_utils.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for python/grouped_conv_utils.py -- grouped convolution Python utilities. + +Phase 1 TDD: tests written BEFORE implementation exists. 
+Run: python3 -m pytest tests/test_grouped_conv_utils.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ValidationResultBase # noqa: E402 +from grouped_conv_utils import ( # noqa: E402 + GroupedConvValidationResult, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, + format_grouped_conv_summary, +) + + +# ============================================================================= +# VALID CONFIG FIXTURES +# ============================================================================= + + +def make_valid_grouped_conv_config(): + """Return a valid grouped conv config dict for gfx942.""" + return { + "tile_config": { + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + }, + "trait_config": { + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "intrawave", + }, + "variant": "2d_fwd", + "ndim_spatial": 2, + "arch": "gfx942", + "layout": "nhwgc", + "dtype": "fp16", + } + + +# ============================================================================= +# TestGroupedConvValidationResult +# ============================================================================= + + +class TestGroupedConvValidationResult(unittest.TestCase): + """Tests for GroupedConvValidationResult dataclass.""" + + def test_inherits_from_validation_result_base(self): + """GroupedConvValidationResult should inherit from ValidationResultBase.""" + self.assertTrue( + issubclass(GroupedConvValidationResult, ValidationResultBase), + "GroupedConvValidationResult must inherit from ValidationResultBase", + ) + + def test_valid_result_has_is_valid(self): + """Valid result has is_valid=True.""" + vr = 
GroupedConvValidationResult(is_valid=True) + self.assertTrue(vr.is_valid) + + def test_invalid_result_has_is_valid_false(self): + """Invalid result has is_valid=False.""" + vr = GroupedConvValidationResult(is_valid=False, errors=["bad config"]) + self.assertFalse(vr.is_valid) + + def test_has_errors_list(self): + """Result has errors list.""" + vr = GroupedConvValidationResult( + is_valid=False, + errors=["invalid wave", "invalid trait"], + ) + self.assertEqual(len(vr.errors), 2) + self.assertIn("invalid wave", vr.errors) + self.assertIn("invalid trait", vr.errors) + + def test_has_warnings_list(self): + """Result has warnings list.""" + vr = GroupedConvValidationResult( + is_valid=True, + warnings=["deprecated option"], + ) + self.assertEqual(len(vr.warnings), 1) + self.assertIn("deprecated option", vr.warnings) + + def test_has_suggested_fixes_dict(self): + """Result has suggested_fixes dict.""" + vr = GroupedConvValidationResult( + is_valid=False, + suggested_fixes={"wave_m": 2, "wave_n": 2}, + ) + self.assertIn("wave_m", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_m"], 2) + self.assertIn("wave_n", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_n"], 2) + + def test_default_empty_errors_warnings_fixes(self): + """Default result has empty errors, warnings, suggested_fixes.""" + vr = GroupedConvValidationResult(is_valid=True) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + +# ============================================================================= +# TestValidateGroupedConvConfig +# ============================================================================= + + +class TestValidateGroupedConvConfig(unittest.TestCase): + """Tests for validate_grouped_conv_config.""" + + def test_valid_config_passes(self): + """Valid config should pass validation.""" + config = make_valid_grouped_conv_config() + result = validate_grouped_conv_config(config) + 
self.assertTrue(result.is_valid, f"Expected valid, got errors: {result.errors}") + self.assertEqual(result.errors, []) + + def test_invalid_wave_config_fails(self): + """Invalid wave config should fail validation.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("wave", error_str) + + def test_invalid_trait_fails(self): + """Invalid trait combination should fail validation.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + config["trait_config"]["scheduler"] = "interwave" # Invalid combo + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("trait", error_str) + + def test_missing_fields_fails(self): + """Config with missing required fields should fail validation.""" + config = {"arch": "gfx942"} # Missing tile_config, trait_config, etc. 
+ result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + + +# ============================================================================= +# TestAutoCorrectGroupedConvConfig +# ============================================================================= + + +class TestAutoCorrectGroupedConvConfig(unittest.TestCase): + """Tests for auto_correct_grouped_conv_config.""" + + def test_invalid_wave_gets_corrected(self): + """Invalid wave config should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Corrected wave should be valid for arch + wave_m = corrected.get("tile_config", {}).get("wave_m") + wave_n = corrected.get("tile_config", {}).get("wave_n") + self.assertIn(wave_m, [1, 2, 4]) + self.assertIn(wave_n, [1, 2, 4]) + + def test_invalid_trait_gets_corrected(self): + """Invalid trait combination should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["scheduler"] = "interwave" + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Scheduler should be corrected to intrawave for compv4+cshuffle + scheduler = corrected.get("trait_config", {}).get("scheduler") + self.assertEqual(scheduler, "intrawave") + + +# ============================================================================= +# TestGetGroupedConvDefaultConfig +# ============================================================================= + + +class TestGetGroupedConvDefaultConfig(unittest.TestCase): + """Tests for 
get_grouped_conv_default_config.""" + + def test_returns_config(self): + """Should return a GroupedConvKernelConfig (or dict via to_dict).""" + config = get_grouped_conv_default_config("2d_fwd") + # Accepts both dataclass and dict + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIsInstance(d, dict) + + def test_has_tile_config(self): + """Returned config has tile_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("tile_config", d) + self.assertIsInstance(d["tile_config"], dict) + + def test_has_trait_config(self): + """Returned config has trait_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("trait_config", d) + self.assertIsInstance(d["trait_config"], dict) + + def test_has_variant(self): + """Returned config has variant.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("variant", d) + + def test_has_ndim_spatial(self): + """Returned config has ndim_spatial.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("ndim_spatial", d) + + def test_has_arch(self): + """Returned config has arch.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("arch", d) + + def test_has_layout(self): + """Returned config has layout.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("layout", d) + + +# ============================================================================= +# TestGroupedConvDataType +# ============================================================================= + + +class 
TestGroupedConvDataType(unittest.TestCase): + """Tests for GroupedConvDataType enum.""" + + def test_fp16_exists(self): + """GroupedConvDataType has FP16.""" + self.assertIsNotNone(GroupedConvDataType.FP16) + + def test_bf16_exists(self): + """GroupedConvDataType has BF16.""" + self.assertIsNotNone(GroupedConvDataType.BF16) + + def test_fp32_exists(self): + """GroupedConvDataType has FP32.""" + self.assertIsNotNone(GroupedConvDataType.FP32) + + def test_fp8_exists(self): + """GroupedConvDataType has FP8.""" + self.assertIsNotNone(GroupedConvDataType.FP8) + + def test_bf8_exists(self): + """GroupedConvDataType has BF8.""" + self.assertIsNotNone(GroupedConvDataType.BF8) + + def test_int8_exists(self): + """GroupedConvDataType has INT8.""" + self.assertIsNotNone(GroupedConvDataType.INT8) + + def test_enum_values_unique(self): + """All enum values should be unique.""" + values = [ + GroupedConvDataType.FP16, + GroupedConvDataType.BF16, + GroupedConvDataType.FP32, + GroupedConvDataType.FP8, + GroupedConvDataType.BF8, + GroupedConvDataType.INT8, + ] + self.assertEqual(len(values), len(set(values))) + + +# ============================================================================= +# TestFormatGroupedConvSummary +# ============================================================================= + + +class TestFormatGroupedConvSummary(unittest.TestCase): + """Tests for format_grouped_conv_summary.""" + + def test_returns_non_empty_string(self): + """Should return a non-empty string.""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + self.assertIsInstance(summary, str) + self.assertGreater(len(summary), 0) + + def test_contains_key_info(self): + """Summary should contain key config info (variant, arch, layout, dtype).""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + # Should mention at least some of: variant, arch, layout, dtype + summary_lower = summary.lower() + has_key_info = ( + 
"2d" in summary_lower + or "fwd" in summary_lower + or "gfx" in summary_lower + or "nhwgc" in summary_lower + or "fp16" in summary_lower + ) + self.assertTrue( + has_key_info, + f"Summary should contain key info, got: {summary}", + ) + + def test_empty_config_returns_something(self): + """Empty or minimal config should still return a string.""" + summary = format_grouped_conv_summary({}) + self.assertIsInstance(summary, str) + self.assertGreaterEqual(len(summary), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp index 21ea545292..ba6068e3ee 100644 --- a/dispatcher/tests/test_problem_extended.cpp +++ b/dispatcher/tests/test_problem_extended.cpp @@ -19,7 +19,7 @@ class ProblemDimensionInferenceTest : public ::testing::Test TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) { - // A: M×K (1024×512), B: K×N (512×2048) + // A: MxK (1024x512), B: KxN (512x2048) auto problem = Problem::from_ab(1024, 512, 512, 2048); EXPECT_EQ(problem.M, 1024); @@ -30,7 +30,7 @@ TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) { - // A: 1024×512, B: 512×2048, C: 1024×2048 + // A: 1024x512, B: 512x2048, C: 1024x2048 auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); EXPECT_EQ(problem.M, 1024); @@ -55,7 +55,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) { - // A stored as K×M (transposed) + // A stored as KxM (transposed) TensorShape A{512, 1024, true}; TensorShape B{512, 2048, false}; TensorShape C{1024, 2048, false}; @@ -70,7 +70,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) { TensorShape A{1024, 512, false}; - // B stored as N×K (transposed) + // B stored as NxK (transposed) TensorShape B{2048, 512, true}; TensorShape C{1024, 2048, false}; 
diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp index f23f684631..79282da557 100644 --- a/dispatcher/tests/test_real_kernel_multi_size.cpp +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -187,7 +187,7 @@ int main() for(const auto& r : results) { char size_str[32]; - snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + snprintf(size_str, sizeof(size_str), "%4dx%4dx%4d", r.M, r.N, r.K); printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", size_str, diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp index ff3d635968..29c7c80ac3 100644 --- a/dispatcher/tests/test_real_kernel_performance.cpp +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -144,7 +144,7 @@ int main() all_passed = all_passed && passed; char size_label[32]; - snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + snprintf(size_label, sizeof(size_label), "%s %d^3", label, M); printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", size_label,