[TILE ENGINE] Restructure to Base class of GEMM (#3434)

This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-12-19 09:53:56 -06:00
committed by GitHub
parent 0fd2b2f045
commit e22622f0ec
41 changed files with 2246 additions and 3458 deletions

View File

@@ -1,310 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# User-facing cache entries that select which GEMM benchmark variants get configured.
# Both lists are semicolon-separated and are iterated by the main build logic below.
set(GEMM_DATATYPE
    "fp8;fp16"
    CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_LAYOUT
    "rcr;rrr;crr;ccr"
    CACHE STRING "List of layout for GEMM (semicolon-separated)")
# Optional override of the JSON config; an empty value selects the default config.
set(GEMM_CONFIG_FILE
    ""
    CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
# Opt-in switch for routing GEMM compilation through ccache.
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
# Remember this list file's directory so the helper functions can locate
# the generator script and benchmark source regardless of who calls them.
set(GEMM_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
# create_individual_gemm_target(<datatype> <layout> <trait> <tile_config> <config_json>)
#
# Defines one benchmark executable for a single GEMM kernel configuration,
# plus the build-time rule that generates the kernel's instance header.
#
# Arguments:
#   datatype    - datatype token; used in target/file names and forwarded to the builder script
#   layout      - layout token; used the same way
#   trait       - underscore-separated trait string; its first three fields are read
#                 as pipeline, epilogue and scheduler (in that order)
#   tile_config - three underscore-separated groups (tile, warp, warp tile), each made of
#                 three "x"-separated fields, e.g. 256x256x32_4x1x1_16x16x16
#   config_json - path of the JSON config forwarded to the builder script
#
# Reads GEMM_GPU_TARGETS_INDIVIDUAL and GEMM_SOURCE_DIR from the calling scope.
# Every collection target named in the add_dependencies() calls below
# (benchmark_gemm_all, benchmark_gemm_<datatype>, ...) must already exist.
function(create_individual_gemm_target datatype layout trait tile_config config_json)
  if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
    message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
    return()
  endif()

  # Validate the shape of tile_config up front. The individual dimensions are not
  # needed here (the whole string is forwarded to the builder script), but a
  # malformed value should fail with a readable message rather than later on.
  string(REPLACE "_" ";" config_groups "${tile_config}")
  list(LENGTH config_groups config_group_count)
  if(config_group_count LESS 3)
    message(FATAL_ERROR "Malformed tile_config '${tile_config}': expected <tile>_<warp>_<warp_tile>")
  endif()
  foreach(group_index RANGE 0 2)
    list(GET config_groups ${group_index} group_dims)
    string(REPLACE "x" ";" group_parts "${group_dims}")
    list(LENGTH group_parts group_part_count)
    if(group_part_count LESS 3)
      message(FATAL_ERROR "Malformed tile_config '${tile_config}': group '${group_dims}' needs three x-separated fields")
    endif()
  endforeach()

  # The trait string must provide at least pipeline, epilogue and scheduler.
  string(REPLACE "_" ";" trait_parts "${trait}")
  list(LENGTH trait_parts trait_part_count)
  if(trait_part_count LESS 3)
    message(FATAL_ERROR "Malformed trait '${trait}': expected <pipeline>_<epilogue>_<scheduler>[_...]")
  endif()
  list(GET trait_parts 0 pipeline)
  list(GET trait_parts 1 epilogue)
  list(GET trait_parts 2 scheduler)

  set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}")
  set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
  # Single-instance header generated for this kernel
  set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

  # Generate the header at build time. VERBATIM keeps argument escaping
  # platform-independent; this matters for the semicolon-separated GPU target
  # list and for paths containing spaces.
  add_custom_command(
    OUTPUT "${instance_header}"
    COMMAND ${Python3_EXECUTABLE} "${GEMM_SOURCE_DIR}/gemm_instance_builder.py"
            --working_path "${working_path}"
            --datatype "${datatype}"
            --layout "${layout}"
            --config_json "${config_json}"
            --gen_single
            --kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}"
            --tile_config "${tile_config}"
            --trait_combo "${trait}"
            --gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}"
    DEPENDS "${GEMM_SOURCE_DIR}/gemm_instance_builder.py" "${config_json}"
    COMMENT "Generating ${instance_header}"
    VERBATIM
  )

  # Create the executable
  add_executable(${target_name}
    # to save build time, exclude the target from "all" target of "gemm" directory and its ancestors
    EXCLUDE_FROM_ALL
    "${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp"
    "${instance_header}"
  )

  # Set GPU architectures
  set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})

  # Expose the generated header path to the benchmark source as a string macro
  target_compile_definitions(${target_name} PRIVATE
    GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
  )

  target_include_directories(${target_name} PRIVATE
    "${GEMM_SOURCE_DIR}"
    "${working_path}"
  )

  # The generated header is also force-included into the translation unit
  target_compile_options(${target_name} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
    -include ${instance_header}
  )

  # Hook the executable into the datatype/layout collection targets
  add_dependencies(benchmark_gemm_all ${target_name})
  add_dependencies(benchmark_gemm_${datatype} ${target_name})
  add_dependencies(benchmark_gemm_${layout} ${target_name})
  add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name})

  # ... and into the trait-specific collection targets
  add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name})
  add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name})
  add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name})
endfunction()
# build_individual_gemm_targets(<datatype> <layout>)
#
# Runs the instance builder script at configure time to enumerate the kernels for
# one datatype/layout pair, then calls create_individual_gemm_target() for each.
#
# Config file selection (first match wins):
#   1. Environment variable GEMM_CONFIG_FILE
#   2. CMake variable GEMM_CONFIG_FILE
#   3. configs/default_config.json
# In all cases the file is looked up in the configs/ folder next to the list file
# being processed. Configuration aborts if the file is missing, the script fails,
# or the script does not produce the kernel count/list files.
function(build_individual_gemm_targets datatype layout)
  set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
  set(builder_script "${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")

  # Check environment variable first
  if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
    set(config_filename "$ENV{GEMM_CONFIG_FILE}")
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
    message(VERBOSE " Using config from environment variable: ${config_filename}")
  elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
    # Use CMake variable if set
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
    message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}")
  else()
    # Use default config for all layouts
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
    message(VERBOSE " Using default config for layout ${layout}")
  endif()

  # Check if config file exists
  if(NOT EXISTS "${json_blob}")
    message(FATAL_ERROR "Config file not found: ${json_blob}")
  endif()

  # The set of targets is derived from the config and the script at configure
  # time, so editing either one must trigger a re-configure.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
    "${json_blob}" "${builder_script}")

  # Worker count; in this function it is only reported in the verbose log below.
  if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    set(num_workers "$ENV{CMAKE_BUILD_PARALLEL_LEVEL}")
  else()
    # Use processor count but limit to avoid memory issues
    cmake_host_system_information(RESULT num_workers QUERY NUMBER_OF_LOGICAL_CORES)
    if(num_workers GREATER 8)
      set(num_workers 8)
    endif()
  endif()

  message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
  message(VERBOSE " Working path: ${working_path}")
  message(VERBOSE " Config file: ${json_blob}")
  message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
  message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")

  # Create working directory first
  file(MAKE_DIRECTORY "${working_path}")

  # Drop results of a previous configure so the existence checks below really
  # verify that this run of the script produced them.
  set(kernel_count_file "${working_path}/gemm_kernel_count.txt")
  set(kernel_list_file "${working_path}/gemm_kernel_list.txt")
  file(REMOVE "${kernel_count_file}" "${kernel_list_file}")

  message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")

  # First, just list the kernels (fast operation)
  message(VERBOSE " Listing kernel configurations...")
  # NOTE(review): --gpu_target is expanded unquoted here (one argument per GPU
  # target), while the build-time generator command passes the list as a single
  # semicolon-joined argument. Left as-is; confirm which form the script expects.
  execute_process(
    COMMAND ${Python3_EXECUTABLE} -u "${builder_script}"
            --working_path "${working_path}"
            --datatype "${datatype}"
            --layout "${layout}"
            --config_json "${json_blob}"
            --gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
            --list_kernels
    WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}"
    RESULT_VARIABLE ret
    OUTPUT_VARIABLE list_output
    ERROR_VARIABLE list_error
  )
  if(NOT ret EQUAL 0)
    message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
  endif()

  # Read kernel count (used for the verbose log only)
  if(NOT EXISTS "${kernel_count_file}")
    message(FATAL_ERROR "Kernel count file not found")
  endif()
  file(READ "${kernel_count_file}" kernel_count)
  string(STRIP "${kernel_count}" kernel_count)
  message(VERBOSE " Found ${kernel_count} kernel configurations")

  # Read kernel list and create targets
  if(NOT EXISTS "${kernel_list_file}")
    message(FATAL_ERROR "Kernel list file not found")
  endif()
  file(STRINGS "${kernel_list_file}" kernel_lines)
  foreach(line IN LISTS kernel_lines)
    # Each line has the form kernel_name|tile_config|trait_combo; only the last
    # two fields are needed to create the target.
    string(REPLACE "|" ";" parts "${line}")
    list(LENGTH parts part_count)
    if(part_count LESS 3)
      message(FATAL_ERROR "Malformed kernel list entry in ${kernel_list_file}: '${line}'")
    endif()
    list(GET parts 1 tile_config)
    list(GET parts 2 trait_combo)
    create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
  endforeach()
endfunction()
# Entry point of the GEMM benchmark configuration. Only per-kernel ("individual")
# builds are supported.
message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===")
message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}")
message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Restrict SUPPORTED_GPU_TARGETS to gfx90a, gfx942, gfx950 and gfx1201.
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
foreach(gpu_arch IN LISTS SUPPORTED_GPU_TARGETS)
  if(gpu_arch IN_LIST DESIRED_TARGETS)
    list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${gpu_arch})
    message(VERBOSE " Adding GPU target: ${gpu_arch}")
  endif()
endforeach()

if(GEMM_GPU_TARGETS_INDIVIDUAL)
  message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")

  # Job pools for parallel compilation control:
  #   compile_heavy  - limits heavy compilations to prevent OOM
  #   compile_normal - allows more parallel normal compilations
  set_property(GLOBAL PROPERTY JOB_POOLS compile_heavy=4 compile_normal=16)

  # ccache is opt-in: it is off by default due to permission issues in CI environments.
  if(NOT ENABLE_CCACHE_GEMM)
    message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
  else()
    find_program(CCACHE_PROGRAM ccache)
    if(NOT CCACHE_PROGRAM)
      message(WARNING "ccache requested but not found")
    else()
      set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
      message(VERBOSE "Using ccache for faster compilation")
    endif()
  endif()

  # Umbrella target that every individual benchmark is attached to.
  add_custom_target(benchmark_gemm_all)

  # One collection target per layout ...
  foreach(gemm_layout IN LISTS GEMM_LAYOUT)
    add_custom_target(benchmark_gemm_${gemm_layout})
  endforeach()

  # ... one per datatype, and one per datatype/layout pair.
  foreach(gemm_dtype IN LISTS GEMM_DATATYPE)
    add_custom_target(benchmark_gemm_${gemm_dtype})
    foreach(gemm_layout IN LISTS GEMM_LAYOUT)
      add_custom_target(benchmark_gemm_${gemm_dtype}_${gemm_layout})
    endforeach()
  endforeach()

  # Trait-based collection targets, named benchmark_gemm_<value>_<pipeline|epilogue|scheduler>.
  # These are the common trait components used across all GEMM kernels.
  set(GEMM_PIPELINES "mem;compv3;compv4")
  set(GEMM_EPILOGUES "default;cshuffle")
  set(GEMM_SCHEDULERS "intrawave;interwave")
  foreach(trait_kind IN ITEMS PIPELINE EPILOGUE SCHEDULER)
    string(TOLOWER "${trait_kind}" trait_suffix)
    foreach(trait_value IN LISTS GEMM_${trait_kind}S)
      add_custom_target(benchmark_gemm_${trait_value}_${trait_suffix})
    endforeach()
  endforeach()

  # With all collection targets in place, create the individual benchmark
  # targets for every datatype/layout combination.
  foreach(gemm_dtype IN LISTS GEMM_DATATYPE)
    foreach(gemm_layout IN LISTS GEMM_LAYOUT)
      build_individual_gemm_targets(${gemm_dtype} ${gemm_layout})
    endforeach()
  endforeach()
else()
  # Nothing to build for the configured architectures.
  message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
endif()

add_subdirectory(gemm_universal)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)

View File

@@ -1,442 +0,0 @@
# CK Tile Engine GEMM Operations
## Overview
The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities.
## Table of Contents
1. [Build System Architecture](#build-system-architecture)
2. [Build Instructions](#build-instructions)
3. [Running Benchmarks](#running-benchmarks)
4. [Configuration System](#configuration-system)
5. [Scripts and Tools](#scripts-and-tools)
6. [Command Line Options](#command-line-options)
7. [Understanding Kernel Names](#understanding-kernel-names)
8. [Troubleshooting](#troubleshooting)
9. [Performance Tips](#performance-tips)
## Build System Architecture
### Individual Kernel Compilation (New Approach)
The new tile engine benchmark system compiles each kernel configuration into a separate executable. This provides:
- Better build parallelism
- Faster incremental builds
- More targeted testing
- Easier debugging of specific configurations
Each benchmark executable follows the naming pattern:
```
benchmark_gemm_<dtype>_<layout>_<config>_<tile_sizes>
```
### Monolithic Build (Legacy Approach)
The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments.
## Build Instructions
### Prerequisites
- ROCm installation
- CMake 3.16 or higher
- C++17 compatible compiler
### Basic Build
```bash
# In the root of composable kernel, create build directory
mkdir build && cd build
# Configure with specific datatypes and layouts
# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942)
# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64)
# Replace [Layout1;Layout2;...] with layouts (rcr, rrr, crr, ccr)
../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]"
# Build specific benchmarks
make benchmark_gemm_[Datatype1]_[Layout1] -j
```
### Configuration Options
The build system supports several configuration options:
#### Using Custom Config Files
```bash
# Method 1: CMake variable (config file must be in configs/ directory)
cmake -DGEMM_CONFIG_FILE=my_custom_config.json ...
# Method 2: Environment variable (takes precedence over CMake variable)
export GEMM_CONFIG_FILE=my_custom_config.json
cmake ...
```
#### Config File Priority Order
1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority)
2. **CMake variable** `GEMM_CONFIG_FILE`
3. **Default config** (default_config.json for all layouts)
**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory.
### Example Build Commands
```bash
# Configure for gfx942 with fp8 and fp16 datatypes and all four layouts, then build the rcr benchmarks
mkdir build && cd build
../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr"
make benchmark_gemm_fp8_rcr -j
make benchmark_gemm_fp16_rcr -j
```
### Building Individual Kernels
```bash
# Build a specific kernel configuration
make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32
# Build all fp16 benchmarks in parallel
make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}')
```
### Rebuilding After Configuration Changes
If you modify the configuration file, you must rebuild:
```bash
rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j
```
## Running Benchmarks
### Individual Kernel Execution
```bash
cd /path/to/build/directory
./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \
-m=512 -n=512 -k=512 -verify=1
```
### Monolithic Executable (Legacy)
```bash
# Run specific pipeline/scheduler/epilogue combination
./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default
```
### Automated Testing
Use the provided test script to run multiple benchmarks:
```bash
cd /path/to/composable_kernel/tile_engine/ops/gemm
./test_benchmark.sh [build_directory]
```
## Configuration System
### Configuration Files
The system uses JSON configuration files to specify kernel parameters:
- `configs/default_config.json` - Default configurations for various datatypes
- `configs/user_provided_config.json` - User-customizable configurations
### Configuration Structure
```json
{
"tile_config": {
"tile_m": {"values": [256, 128]},
"tile_n": {"values": [256, 128]},
"tile_k": {"values": [64, 32]},
"warp_m": {"values": [2, 4]},
"warp_n": {"values": [2, 1]},
"warp_k": {"values": [1]},
"warp_tile_m": {"values": [32, 16]},
"warp_tile_n": {"values": [32, 16]},
"warp_tile_k": {"values": [16, 32]}
},
"trait_config": {
"pipeline": {"values": ["compv3", "compv4", "mem"]},
"scheduler": {"values": ["intrawave", "interwave"]},
"epilogue": {"values": ["default", "cshuffle"]},
"pad_m": {"values": [false]},
"pad_n": {"values": [false]},
"pad_k": {"values": [false]},
"persistent": {"values": [false]}
}
}
```
## Scripts and Tools
### Python Scripts
#### gemm_instance_builder.py
**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files.
**Key Features**:
- Generates individual kernel header files for separate compilation
- Supports multiple data types (fp16, fp8, bf16, fp32, fp64)
- Validates tile configurations for correctness
- Creates CMake integration files
**Usage**:
```bash
python gemm_instance_builder.py \
--working_path ./generated \
--datatype fp16 \
--layout rcr \
--config_json configs/user_provided_config.json \
--gen_all_individual
```
#### gemm_instance_builder_parallel.py
**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations.
**Features**:
- Multi-threaded kernel generation
- Improved performance for large configuration spaces
#### validation_utils.py
**Purpose**: Provides comprehensive validation functions for kernel configurations.
**Key Functions**:
- `is_tile_config_valid()` - Validates tile dimensions and alignments
- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported
- `validate_warp_tile_combination()` - GPU-specific warp tile validation
- `validate_lds_capacity()` - Ensures configurations fit in LDS memory
**Validation Checks**:
- Dimension alignment (tile dimensions must be divisible by warp dimensions)
- LDS capacity constraints
- GPU-specific warp tile support
- Unsupported trait combinations
#### test_validation.py
**Purpose**: Test suite for the validation logic to ensure correctness.
**Usage**:
```bash
python test_validation.py
```
**Tests**:
- Warp tile combination validation
- Trait combination validation
- Full tile configuration validation
#### gemm_benchmark.py
**Purpose**: Python script for running and analyzing GEMM benchmarks.
**Features**:
- Automated benchmark execution
- Performance data collection
- Result analysis and reporting
#### json_config.py
**Purpose**: Configuration file parsing and management.
**Features**:
- JSON configuration loading
- Default configuration handling
- Configuration validation
#### codegen_utils.py
**Purpose**: Utility functions for code generation.
**Features**:
- Template processing
- Code formatting utilities
- File generation helpers
### Shell Scripts
#### test_benchmark.sh
**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables.
**Features**:
- Automatic build directory detection
- Batch execution of multiple benchmarks
- CSV result collection
- Colored output for easy reading
- Example command generation
**Usage**:
```bash
# Auto-detect build directory
./test_benchmark.sh
# Specify build directory
./test_benchmark.sh /path/to/build/directory
```
**What it does**:
1. Finds all benchmark executables in the build directory
2. Runs each with multiple problem sizes (512, 1024, 2048)
3. Performs GPU verification
4. Saves results to timestamped CSV file
5. Provides summary statistics
## Command Line Options
All benchmark executables support the following options:
### Matrix Dimensions
- `-m=<value>` - M dimension (default: 3840)
- `-n=<value>` - N dimension (default: 4096)
- `-k=<value>` - K dimension (default: 2048)
### Strides
- `-stride_a=<value>` - Stride for matrix A (default: 0, auto-calculated)
- `-stride_b=<value>` - Stride for matrix B (default: 0, auto-calculated)
- `-stride_c=<value>` - Stride for matrix C (default: 0, auto-calculated)
### Verification
- `-verify=<0|1|2>` - Verification mode
  - 0: No verification (default)
  - 1: CPU verification
  - 2: GPU verification
### Performance Testing
- `-warmup=<value>` - Warmup iterations (default: 50)
- `-repeat=<value>` - Benchmark iterations (default: 100)
- `-timer=<true|false>` - Use GPU timer (default: true)
- `-flush_cache=<true|false>` - Flush cache between runs (default: true)
- `-rotating_count=<value>` - Cache rotation count (default: 1000)
### Initialization
- `-init=<0|1|2>` - Tensor initialization method
  - 0: Random values [-1, 1] (default)
  - 1: Linear sequence (i % 17)
  - 2: Constant value (1.0)
### Output Options
- `-log=<true|false>` - Enable verbose logging (default: false)
- `-metric=<0|1|2>` - Performance metric
  - 0: Latency in ms (default)
  - 1: TFLOPS
  - 2: Bandwidth in GB/s
- `-json_output=<true|false>` - JSON format output (default: false)
- `-csv_filename=<filename>` - Save results to CSV
- `-csv_format=<simple|comprehensive>` - CSV format (default: comprehensive)
### Advanced Options
- `-split_k=<value>` - Split-K factor (default: 1)
- `-structured_sparsity=<true|false>` - Enable structured sparsity (default: false)
- `-pipeline=<compv3|compv4|mem>` - Pipeline type (default: compv3)
- `-scheduler=<intrawave|interwave>` - Scheduler type (default: intrawave)
- `-epilogue=<cshuffle|default>` - Epilogue type (default: cshuffle)
- `-pad_m=<true|false>` - Pad M dimension (default: false)
- `-pad_n=<true|false>` - Pad N dimension (default: false)
- `-pad_k=<true|false>` - Pad K dimension (default: false)
- `-persistent=<true|false>` - Use persistent kernel (default: false)
## Understanding Kernel Names
The kernel naming convention encodes the configuration:
```
benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16
               ^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^ ^^^^^ ^^^^^^^^
               |    |   |      |       |         |                       |          |     |
               |    |   |      |       |         Padding & flags         |          |     Warp tile
               |    |   |      |       Scheduler                         |          Warp dims
               |    |   |      Epilogue                                  Block tile
               |    |   Pipeline
               |    Layout (Row-Column-Row)
               Data type
```
### Components:
- **Data type**: fp16, fp32, bf16, fp8, bf8, int8
- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr
- **Pipeline**: mem, compv3, compv4
- **Epilogue**: default, cshuffle
- **Scheduler**: intrawave, interwave
- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags)
- **Tile sizes**: three underscore-separated `MxNxK` groups, in order: block tile, warp dims, warp tile (e.g. `256x128x32_4x1x1_32x32x16`)
## Troubleshooting
### Common Issues
1. **Kernel not found**
   - Ensure the specific benchmark executable is built
   - Check the `bin/` folder of the build directory
2. **Verification failures**
   - Try GPU verification (`-verify=2`), which may be more accurate
   - Check data type compatibility
   - Verify stride calculations
3. **Build failures**
   - Check GPU architecture compatibility
   - Ensure ROCm is properly installed
   - Verify configuration file syntax
4. **Performance variations**
   - Increase warmup iterations
   - Disable CPU frequency scaling
   - Use GPU timer for accurate measurements
### Debug Options
Enable verbose logging:
```bash
./bin/benchmark_gemm_... -log=true -verify=1
```
Test validation logic:
```bash
python test_validation.py
```
## Performance Tips
1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions
2. **Warmup**: Use at least 50-100 warmup iterations
3. **GPU Timer**: Always use `-timer=true` for accurate measurements
4. **Cache Management**: Enable cache flushing for consistent results
5. **Thread Affinity**: Set CPU affinity to reduce variation
## Integration Examples
### Python Integration
```python
import subprocess
import json
# Run benchmark with JSON output
result = subprocess.run([
'./bin/benchmark_gemm_fp16_rcr_...',
'-m=1024', '-n=1024', '-k=1024',
'-json_output=true'
], capture_output=True, text=True)
# Parse results
data = json.loads(result.stdout)
print(f"Performance: {data['tflops']} TFLOPS")
```
### Batch Testing Script
```bash
#!/bin/bash
SIZES="512 1024 2048 4096"
for size in $SIZES; do
echo "Testing ${size}x${size}x${size}"
./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \
-verify=2 -csv_filename=results.csv
done
```
## Contributing
When adding new features or configurations:
1. Update validation logic in `validation_utils.py`
2. Add tests to `test_validation.py`
3. Update configuration examples
4. Document new command-line options
For more information about the Composable Kernel project, visit the main repository documentation.

View File

@@ -1,41 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Bundle of kernel trait selections handed to the dispatcher.
// A default-constructed instance selects compv3 / intrawave / cshuffle
// with all padding flags and the persistent flag switched off.
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m      = false;
    bool pad_n      = false;
    bool pad_k      = false;
    bool persistent = false;

    // All defaults come from the member initializers above.
    KernelTraits() {}
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,314 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# User-facing cache entries that select which GEMM Multi D benchmark variants get configured.
# The datatype and layout lists are semicolon-separated.
set(GEMM_MULTI_D_DATATYPE
    "fp16"
    CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
set(GEMM_MULTI_D_LAYOUT
    "rcrr;rrrr;crrr;ccrr"
    CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)")
# Optional override of the JSON config; an empty value selects the default config.
set(GEMM_MULTI_D_CONFIG_FILE
    ""
    CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
# Forwarded to the instance builder script as --elementwise_function.
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION
    "mul"
    CACHE STRING "Elementwise function")
# Opt-in switch for routing GEMM Multi D compilation through ccache.
option(ENABLE_CCACHE_GEMM_MULTI_D "Enable ccache for GEMM Multi D ops compilation" OFF)
# Remember this list file's directory so the helper functions can locate
# the generator script and benchmark source regardless of who calls them.
set(GEMM_MULTI_D_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
# create_individual_gemm_multi_d_target(<datatype> <layout> <trait> <tile_config> <config_json>)
#
# Defines one benchmark executable for a single GEMM Multi D kernel configuration,
# plus the build-time rule that generates the kernel's instance header.
#
# Arguments:
#   datatype    - datatype token; used in target/file names and forwarded to the builder script
#   layout      - layout token; used the same way
#   trait       - underscore-separated trait string; its first three fields are read
#                 as pipeline, epilogue and scheduler (in that order)
#   tile_config - three underscore-separated groups (tile, warp, warp tile), each made of
#                 three "x"-separated fields, e.g. 256x256x32_4x1x1_16x16x16
#   config_json - path of the JSON config forwarded to the builder script
#
# Reads GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL, GEMM_MULTI_D_SOURCE_DIR and
# GEMM_MULTI_D_ELEMENTWISE_FUNCTION from the calling scope. Every collection
# target named in the add_dependencies() calls below must already exist.
function(create_individual_gemm_multi_d_target datatype layout trait tile_config config_json)
  if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
    message(WARNING "Skipping individual GEMM Multi D target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
    return()
  endif()

  # Validate the shape of tile_config up front. The individual dimensions are not
  # needed here (the whole string is forwarded to the builder script), but a
  # malformed value should fail with a readable message rather than later on.
  string(REPLACE "_" ";" config_groups "${tile_config}")
  list(LENGTH config_groups config_group_count)
  if(config_group_count LESS 3)
    message(FATAL_ERROR "Malformed tile_config '${tile_config}': expected <tile>_<warp>_<warp_tile>")
  endif()
  foreach(group_index RANGE 0 2)
    list(GET config_groups ${group_index} group_dims)
    string(REPLACE "x" ";" group_parts "${group_dims}")
    list(LENGTH group_parts group_part_count)
    if(group_part_count LESS 3)
      message(FATAL_ERROR "Malformed tile_config '${tile_config}': group '${group_dims}' needs three x-separated fields")
    endif()
  endforeach()

  # The trait string must provide at least pipeline, epilogue and scheduler.
  string(REPLACE "_" ";" trait_parts "${trait}")
  list(LENGTH trait_parts trait_part_count)
  if(trait_part_count LESS 3)
    message(FATAL_ERROR "Malformed trait '${trait}': expected <pipeline>_<epilogue>_<scheduler>[_...]")
  endif()
  list(GET trait_parts 0 pipeline)
  list(GET trait_parts 1 epilogue)
  list(GET trait_parts 2 scheduler)

  set(target_name "benchmark_gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}")
  set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
  # Single-instance header generated for this kernel
  set(instance_header "${working_path}/gemm_multi_d_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

  # Generate the header at build time. VERBATIM keeps argument escaping
  # platform-independent; this matters for the semicolon-separated GPU target
  # list and for paths containing spaces.
  add_custom_command(
    OUTPUT "${instance_header}"
    COMMAND ${Python3_EXECUTABLE} "${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py"
            --working_path "${working_path}"
            --datatype "${datatype}"
            --layout "${layout}"
            --elementwise_function "${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}"
            --config_json "${config_json}"
            --gen_single
            --kernel_name "gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}"
            --tile_config "${tile_config}"
            --trait_combo "${trait}"
            --gpu_target "${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}"
    DEPENDS "${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py" "${config_json}"
    COMMENT "Generating ${instance_header}"
    VERBATIM
  )

  # Create the executable; it is kept out of the default "all" build
  add_executable(${target_name}
    EXCLUDE_FROM_ALL
    "${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp"
    "${instance_header}"
  )

  # Set GPU architectures
  set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL})

  # Expose the generated header path to the benchmark source as a string macro
  target_compile_definitions(${target_name} PRIVATE
    GEMM_MULTI_D_SINGLE_INSTANCE_HPP="${instance_header}"
  )

  target_include_directories(${target_name} PRIVATE
    "${GEMM_MULTI_D_SOURCE_DIR}"
    "${working_path}"
  )

  # The generated header is also force-included into the translation unit
  target_compile_options(${target_name} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
    -include ${instance_header}
  )

  # Hook the executable into the datatype/layout collection targets
  add_dependencies(benchmark_gemm_multi_d_all ${target_name})
  add_dependencies(benchmark_gemm_multi_d_${datatype} ${target_name})
  add_dependencies(benchmark_gemm_multi_d_${layout} ${target_name})
  add_dependencies(benchmark_gemm_multi_d_${datatype}_${layout} ${target_name})

  # ... and into the trait-specific collection targets
  add_dependencies(benchmark_gemm_multi_d_${pipeline}_pipeline ${target_name})
  add_dependencies(benchmark_gemm_multi_d_${epilogue}_epilogue ${target_name})
  add_dependencies(benchmark_gemm_multi_d_${scheduler}_scheduler ${target_name})
endfunction()
# build_individual_gemm_multi_d_targets(<datatype> <layout>)
#
# Runs the instance builder script at configure time to enumerate the kernels for
# one datatype/layout pair, then calls create_individual_gemm_multi_d_target()
# for each of them.
#
# Config file selection (first match wins):
#   1. Environment variable GEMM_MULTI_D_CONFIG_FILE
#   2. CMake variable GEMM_MULTI_D_CONFIG_FILE
#   3. configs/default_config.json
# In all cases the file is looked up in the configs/ folder next to the list file
# being processed. Configuration aborts if the file is missing, the script fails,
# or the script does not produce the kernel count/list files.
function(build_individual_gemm_multi_d_targets datatype layout)
  set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
  set(builder_script "${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py")

  # Check environment variable first
  if(DEFINED ENV{GEMM_MULTI_D_CONFIG_FILE} AND NOT "$ENV{GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "")
    set(config_filename "$ENV{GEMM_MULTI_D_CONFIG_FILE}")
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
    message(VERBOSE " Using config from environment variable: ${config_filename}")
  elseif(NOT "${GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "")
    # Use CMake variable if set
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_MULTI_D_CONFIG_FILE}")
    message(VERBOSE " Using custom config: ${GEMM_MULTI_D_CONFIG_FILE}")
  else()
    # Use default config for all layouts
    set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
    message(VERBOSE " Using default config for layout ${layout}")
  endif()

  # Check if config file exists
  if(NOT EXISTS "${json_blob}")
    message(FATAL_ERROR "Config file not found: ${json_blob}")
  endif()

  # The set of targets is derived from the config and the script at configure
  # time, so editing either one must trigger a re-configure.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
    "${json_blob}" "${builder_script}")

  # Worker count; in this function it is only reported in the verbose log below.
  if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    set(num_workers "$ENV{CMAKE_BUILD_PARALLEL_LEVEL}")
  else()
    # Use processor count but limit to avoid memory issues
    cmake_host_system_information(RESULT num_workers QUERY NUMBER_OF_LOGICAL_CORES)
    if(num_workers GREATER 8)
      set(num_workers 8)
    endif()
  endif()

  message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
  message(VERBOSE " Working path: ${working_path}")
  message(VERBOSE " Config file: ${json_blob}")
  message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
  message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py")

  # Create working directory first
  file(MAKE_DIRECTORY "${working_path}")

  # Drop results of a previous configure so the existence checks below really
  # verify that this run of the script produced them.
  set(kernel_count_file "${working_path}/gemm_multi_d_kernel_count.txt")
  set(kernel_list_file "${working_path}/gemm_multi_d_kernel_list.txt")
  file(REMOVE "${kernel_count_file}" "${kernel_list_file}")

  message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")

  # First, just list the kernels (fast operation)
  message(VERBOSE " Listing kernel configurations...")
  # NOTE(review): --gpu_target is expanded unquoted here (one argument per GPU
  # target), while the build-time generator command passes the list as a single
  # semicolon-joined argument. Left as-is; confirm which form the script expects.
  execute_process(
    COMMAND ${Python3_EXECUTABLE} -u "${builder_script}"
            --working_path "${working_path}"
            --datatype "${datatype}"
            --layout "${layout}"
            --elementwise_function "${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}"
            --config_json "${json_blob}"
            --gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}
            --list_kernels
    WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}"
    RESULT_VARIABLE ret
    OUTPUT_VARIABLE list_output
    ERROR_VARIABLE list_error
  )
  if(NOT ret EQUAL 0)
    message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
  endif()

  # Read kernel count (used for the verbose log only)
  if(NOT EXISTS "${kernel_count_file}")
    message(FATAL_ERROR "Kernel count file not found")
  endif()
  file(READ "${kernel_count_file}" kernel_count)
  string(STRIP "${kernel_count}" kernel_count)
  message(VERBOSE " Found ${kernel_count} kernel configurations")

  # Read kernel list and create targets
  if(NOT EXISTS "${kernel_list_file}")
    message(FATAL_ERROR "Kernel list file not found")
  endif()
  file(STRINGS "${kernel_list_file}" kernel_lines)
  foreach(line IN LISTS kernel_lines)
    # Each line has the form kernel_name|tile_config|trait_combo; only the last
    # two fields are needed to create the target.
    string(REPLACE "|" ";" parts "${line}")
    list(LENGTH parts part_count)
    if(part_count LESS 3)
      message(FATAL_ERROR "Malformed kernel list entry in ${kernel_list_file}: '${line}'")
    endif()
    list(GET parts 1 tile_config)
    list(GET parts 2 trait_combo)
    create_individual_gemm_multi_d_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
  endforeach()
endfunction()
# Main build logic - Only individual builds supported
message(VERBOSE "=== Starting Tile Engine GEMM Multi D Configuration ===")
message(VERBOSE "GEMM_MULTI_D_DATATYPE: ${GEMM_MULTI_D_DATATYPE}")
message(VERBOSE "GEMM_MULTI_D_LAYOUT: ${GEMM_MULTI_D_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Keep only gfx90a, gfx942 and gfx950, in the order they appear in
# SUPPORTED_GPU_TARGETS.
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "")
foreach(gpu_arch IN LISTS SUPPORTED_GPU_TARGETS)
  if(NOT gpu_arch IN_LIST DESIRED_TARGETS)
    continue()
  endif()
  list(APPEND GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL ${gpu_arch})
  message(VERBOSE " Adding GPU target: ${gpu_arch}")
endforeach()

if(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL)
  message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}")

  # Job pools for parallel compilation control
  set_property(GLOBAL PROPERTY JOB_POOLS
    compile_heavy=4 # Limit heavy compilations to prevent OOM
    compile_normal=16 # Allow more parallel normal compilations
  )

  # ccache is opt-in: it is disabled by default due to permission issues in
  # CI environments.
  if(NOT ENABLE_CCACHE_GEMM_MULTI_D)
    message(VERBOSE "ccache disabled for GEMM Multi D ops (use -DENABLE_CCACHE_GEMM_MULTI_D=ON to enable)")
  else()
    find_program(CCACHE_PROGRAM ccache)
    if(NOT CCACHE_PROGRAM)
      message(WARNING "ccache requested but not found")
    else()
      set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
      message(VERBOSE "Using ccache for faster compilation")
    endif()
  endif()

  # Master collection target
  add_custom_target(benchmark_gemm_multi_d_all)

  # Per-layout collection targets
  foreach(layout_name IN LISTS GEMM_MULTI_D_LAYOUT)
    add_custom_target(benchmark_gemm_multi_d_${layout_name})
  endforeach()

  # Per-datatype collection targets and the combined datatype/layout ones
  foreach(dtype_name IN LISTS GEMM_MULTI_D_DATATYPE)
    add_custom_target(benchmark_gemm_multi_d_${dtype_name})
    foreach(layout_name IN LISTS GEMM_MULTI_D_LAYOUT)
      add_custom_target(benchmark_gemm_multi_d_${dtype_name}_${layout_name})
    endforeach()
  endforeach()

  # Trait-based collection targets.
  # These are common trait components used across all GEMM Multi D kernels.
  set(GEMM_MULTI_D_PIPELINES "mem;compv3;compv4")
  set(GEMM_MULTI_D_EPILOGUES "default;cshuffle")
  set(GEMM_MULTI_D_SCHEDULERS "intrawave;interwave")
  foreach(trait_kind IN ITEMS pipeline epilogue scheduler)
    # pipeline -> GEMM_MULTI_D_PIPELINES, epilogue -> GEMM_MULTI_D_EPILOGUES, ...
    string(TOUPPER "${trait_kind}" trait_kind_upper)
    foreach(trait_value IN LISTS GEMM_MULTI_D_${trait_kind_upper}S)
      add_custom_target(benchmark_gemm_multi_d_${trait_value}_${trait_kind})
    endforeach()
  endforeach()

  # Build individual targets for each datatype/layout combination. This runs
  # after every collection target above has been created.
  foreach(dtype_name IN LISTS GEMM_MULTI_D_DATATYPE)
    foreach(layout_name IN LISTS GEMM_MULTI_D_LAYOUT)
      build_individual_gemm_multi_d_targets(${dtype_name} ${layout_name})
    endforeach()
  endforeach()
else()
  message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
endif()

View File

@@ -0,0 +1,232 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_multi_d_common.hpp"
// Data types and Layouts are defined by the generated kernel headers
// No hardcoded type definitions here to avoid conflicts
/// Metric selector used by get_metric_name() and PerformanceResult::compare().
enum class Metric
{
LATENCY = 0,   // compared with operator< in PerformanceResult::compare
TFLOPS = 1,    // compared with operator> in PerformanceResult::compare
BANDWIDTH = 2  // compared with operator> in PerformanceResult::compare
};
/// @brief Returns the lower-case name of a metric as a string literal.
/// @throws std::invalid_argument if @p m is not one of the Metric enumerators.
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
    {
        return "latency";
    }
    if(m == Metric::TFLOPS)
    {
        return "tflops";
    }
    if(m == Metric::BANDWIDTH)
    {
        return "bandwidth";
    }
    throw std::invalid_argument("Unsupported metric type");
}
/// Description of one GEMM Multi-D problem instance; streams as a JSON object.
struct GemmMultiDProblem
{
    int split_k_;
    int m_;
    int n_;
    int k_;
    int stride_a_;
    int stride_b_;
    int stride_d0_;
    int stride_d1_;
    int stride_c_;
    std::string dtype_a_;
    std::string dtype_b_;
    std::string dtype_d0_;
    std::string dtype_d1_;
    std::string dtype_acc_;
    std::string dtype_c_;
    std::string layout_a_;
    std::string layout_b_;
    std::string layout_d0_;
    std::string layout_d1_;
    std::string layout_c_;

    /// Writes the problem as a JSON object with one key per line.
    friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
    {
        // Numeric members are emitted bare and always followed by a comma.
        const auto put_number = [&os](const char* key, int value) {
            os << " \"" << key << "\":" << value << ",\n";
        };
        // String members are quoted; `tail` is ",\n" except for the last key.
        const auto put_text = [&os](const char* key, const std::string& value, const char* tail) {
            os << " \"" << key << "\":\"" << value << "\"" << tail;
        };

        os << "{\n";
        put_number("split_k", problem.split_k_);
        put_number("m", problem.m_);
        put_number("n", problem.n_);
        put_number("k", problem.k_);
        put_number("stride_a", problem.stride_a_);
        put_number("stride_b", problem.stride_b_);
        put_number("stride_d0", problem.stride_d0_);
        put_number("stride_d1", problem.stride_d1_);
        put_number("stride_c", problem.stride_c_);
        put_text("dtype_a", problem.dtype_a_, ",\n");
        put_text("dtype_b", problem.dtype_b_, ",\n");
        put_text("dtype_d0", problem.dtype_d0_, ",\n");
        put_text("dtype_d1", problem.dtype_d1_, ",\n");
        put_text("dtype_acc", problem.dtype_acc_, ",\n");
        put_text("dtype_c", problem.dtype_c_, ",\n");
        put_text("layout_a", problem.layout_a_, ",\n");
        put_text("layout_b", problem.layout_b_, ",\n");
        put_text("layout_d0", problem.layout_d0_, ",\n");
        put_text("layout_d1", problem.layout_d1_, ",\n");
        put_text("layout_c", problem.layout_c_, "\n");
        os << "}";
        return os;
    }
};
/// Measured performance of one kernel run; streams as a JSON object.
struct PerformanceResult
{
    double latency_;
    double tflops_;
    double bandwidth_;

    /// @brief Orders two results by the selected metric.
    /// @return true when @p a compares ahead of @p b: smaller latency, or
    ///         larger tflops / bandwidth.
    /// @throws std::invalid_argument for an unknown metric value.
    static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
    {
        switch(m)
        {
        case Metric::LATENCY: return a.latency_ < b.latency_;
        case Metric::TFLOPS: return a.tflops_ > b.tflops_;
        case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
        default: throw std::invalid_argument("Unsupported metric type");
        }
    }

    /// Writes the result as a JSON object with values in fixed notation and
    /// two decimals. The stream's format flags and precision are restored
    /// afterwards so later output on the same stream is not affected.
    friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
    {
        // std::fixed and std::setprecision are sticky; remember the caller's state.
        const std::ios_base::fmtflags saved_flags = os.flags();
        const std::streamsize saved_precision     = os.precision();

        os << "{\n"
           << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
           << ",\n"
           << " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
           << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
           << "}";

        os.flags(saved_flags);
        os.precision(saved_precision);
        return os;
    }
};
struct KernelInstance
{
std::string name_;
GemmMultiDProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << obj.name_ << "\",\n"
<< " \"problem\": " << obj.problem_ << ",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
/// Plain aggregate of benchmark run options.
/// NOTE(review): the fields presumably mirror the command-line options of
/// similar names (warmup, repeat, timer, verify, init, log, csv_filename,
/// flush_cache, rotating_count, json_output) - confirm against the code that
/// fills this struct, which is not part of this header.
struct Setting
{
int n_warmup_;
int n_repeat_;
bool is_gpu_timer_;
int verify_;
int init_method_;
bool log_;
std::string csv_filename_;
bool flush_cache_;
int rotating_count_;
bool json_output_;
};
/// @brief Reads the first line of /opt/rocm/.info/version.
/// @return That line, or "Unknown" when the file cannot be opened.
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    if(!version_file.is_open())
    {
        return "Unknown";
    }

    std::string first_line;
    std::getline(version_file, first_line);
    return first_line;
}
/// @brief Computes the relative and absolute tolerances used for verification.
///
/// The compute type is the smallest (by sizeof) of ADataType, BDataType and
/// D0DataType. Two threshold pairs are evaluated - one for the per-split
/// accumulation over ceil(K / kbatch) and one for the accumulation across
/// the kbatch splits - and the larger value of each is returned.
///
/// @param K                     reduction dimension of the GEMM
/// @param kbatch                split-K factor
/// @param max_accumulated_value value passed to the absolute-threshold helpers
/// @return ck_tile tuple (rtol, atol)
template <typename ADataType,
          typename BDataType,
          typename D0DataType,
          typename AccDataType,
          typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeTypeAB =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;

    // Number of accumulations performed within a single K split.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the accumulation inside one split.
    const auto rtol_gemm =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto atol_gemm = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Thresholds for the error due to split_k accumulation.
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // Use the higher threshold of each pair.
    const auto rtol = std::max(rtol_gemm, rtol_split_k);
    const auto atol = std::max(atol_gemm, atol_split_k);
    return ck_tile::make_tuple(rtol, atol);
}
/// @brief Compares the device result against the host reference.
///
/// Tolerances come from calculate_rtol_atol(), using the largest element of
/// the host result as the maximum accumulated value. The thresholds and the
/// verdict are printed to std::cout.
///
/// ADataType, BDataType, D0DataType, AccDataType and CDataType are not
/// declared in this header; per the note at the top of the file they are
/// defined by the generated kernel headers.
///
/// Declared inline because this is a non-template function defined in a
/// header; without it, including the header from two translation units of
/// the same program produces multiple-definition link errors.
///
/// @param instanceName     kernel name used in the log line
/// @param K                reduction dimension of the GEMM
/// @param kbatch           split-K factor
/// @param c_m_n_dev_result result produced on the device
/// @param c_m_n_host_result reference result
/// @return true when ck_tile::check_err accepts the device result
inline bool compare(std::string instanceName,
                    ck_tile::index_t K,
                    ck_tile::index_t kbatch,
                    ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                    ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
    const float max_accumulated_value =
        *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
    const auto rtol_atol =
        calculate_rtol_atol<ADataType, BDataType, D0DataType, AccDataType, CDataType>(
            K, kbatch, max_accumulated_value);
    bool pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_host_result,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));
    std::cout << "For " << instanceName << " Relative error threshold is "
              << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
              << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
    return pass;
}
/// @brief Fills the host result with the reference GEMM Multi-D output.
///
/// Only the host reference (ck_tile::reference_gemm_multiple_d) is called
/// here, so every positive @p verify value is handled the same way. Nothing
/// is computed when @p verify is zero or negative.
///
/// The data types, DsDataType and ElementWiseFn are not declared in this
/// header; per the note at the top of the file they are defined by the
/// generated kernel headers.
///
/// Declared inline because this is a non-template function defined in a
/// header; without it, including the header from two translation units of
/// the same program produces multiple-definition link errors.
///
/// @param verify            verification mode; the reference runs when > 0
/// @param a_m_k             A operand
/// @param b_k_n             B operand
/// @param d0_m_n            first D operand
/// @param d1_m_n            second D operand
/// @param c_m_n_host_result output tensor written by the reference
inline void gemm_multi_d_host_reference(int verify,
                                        ck_tile::HostTensor<ADataType>& a_m_k,
                                        ck_tile::HostTensor<BDataType>& b_k_n,
                                        ck_tile::HostTensor<D0DataType>& d0_m_n,
                                        ck_tile::HostTensor<D1DataType>& d1_m_n,
                                        ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
    if(verify > 0)
    {
        // Currently only CPU verification is supported for Gemm Multi D
        ck_tile::reference_gemm_multiple_d<ADataType,
                                           BDataType,
                                           DsDataType,
                                           AccDataType,
                                           CDataType,
                                           ElementWiseFn>(
            a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result);
    }
}

View File

@@ -0,0 +1,682 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import sys
import json
import subprocess
import argparse
import csv
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional
class GemmMultiDBenchmark:
def __init__(self, build_dir: str, verbose: bool = False):
self.build_dir = Path(build_dir)
self.verbose = verbose
self.results = []
def discover_kernels(self) -> List[Path]:
"""Find all benchmark_gemm_multi_d_* executables in the build directory"""
bin_dir = self.build_dir / "bin"
if not bin_dir.exists():
print(f"Error: Binary directory {bin_dir} does not exist")
return []
kernels = list(bin_dir.glob("benchmark_gemm_multi_d_*"))
if self.verbose:
print(f"Found {len(kernels)} kernel executables")
for k in kernels:
print(f" - {k.name}")
return kernels
def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
"""Extract comprehensive kernel information from filename"""
name = kernel_path.stem
# Initialize with basic info
info = {
"executable": str(kernel_path),
"name": name,
"data_type": "unknown",
"layout": "unknown",
"pipeline": "unknown",
"scheduler": "unknown",
"epilogue": "unknown",
}
# Parse the kernel name pattern:
# benchmark_gemm_multi_d_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16
parts = name.split("_")
if len(parts) >= 5:
# Extract data type (3rd part after benchmark_gemm_)
info["data_type"] = parts[4] if len(parts) > 4 else "unknown"
# Extract layout (4th part)
info["layout"] = parts[5] if len(parts) > 5 else "unknown"
# Extract pipeline (5th part)
info["pipeline"] = parts[6] if len(parts) > 6 else "unknown"
# Extract epilogue (6th part)
info["epilogue"] = parts[7] if len(parts) > 7 else "unknown"
# Extract scheduler (7th part)
info["scheduler"] = parts[8] if len(parts) > 8 else "unknown"
# Extract detailed configuration from the end of the name
config_info = self.parse_detailed_config(name)
info.update(config_info)
# Generate config ID
info["config_id"] = self.generate_config_id(info)
return info
def parse_detailed_config(self, kernel_name: str) -> Dict:
    """Recover tile/warp dimensions and boolean flags from a kernel name.

    The name is split on underscores. The first run of consecutive
    "True"/"False" tokens supplies pad_m, pad_n, pad_k and persistent (only
    when the run has at least four entries). Tokens of the form ``AxBxC``
    with three positive integers are dimension groups; they are ranked by
    their largest component and assigned largest -> tile sizes,
    middle -> warp tile, smallest -> warp config (with two groups the
    smaller one is the warp config; with one it is the tile sizes).
    """
    config = {
        "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
        "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
        "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
        "optimization_flags": {
            "pad_m": False,
            "pad_n": False,
            "pad_k": False,
            "persistent": False,
        },
    }

    tokens = kernel_name.split("_")

    # Collect the first contiguous run of True/False tokens.
    flag_run = []
    for token in tokens:
        if token in ("True", "False"):
            flag_run.append(token == "True")
        elif flag_run:
            break

    # Order: pad_m, pad_n, pad_k, persistent (4 flags total)
    if len(flag_run) >= 4:
        for flag_name, flag_value in zip(
            ("pad_m", "pad_n", "pad_k", "persistent"), flag_run
        ):
            config["optimization_flags"][flag_name] = flag_value

    # Dimension groups look like 256x256x32: exactly three positive integers.
    dimension_groups = []
    for token in tokens:
        pieces = token.split("x")
        if len(pieces) != 3:
            continue
        try:
            dims = [int(piece) for piece in pieces]
        except ValueError:
            continue
        if min(dims) > 0:
            dimension_groups.append(dims)

    # Rank by largest component (stable, descending) and hand out roles.
    ranked = sorted(dimension_groups, key=max, reverse=True)
    if len(ranked) >= 3:
        roles = ("tile_sizes", "warp_tile", "warp_config")
    else:
        roles = ("tile_sizes", "warp_config")[: len(ranked)]

    key_stems = {"tile_sizes": "tile", "warp_config": "warp", "warp_tile": "warp_tile"}
    for section, dims in zip(roles, ranked):
        stem = key_stems[section]
        for axis, value in zip("mnk", dims):
            config[section][f"{stem}_{axis}"] = value

    return config
def generate_config_id(self, info: Dict) -> str:
    """Build a compact underscore-joined identifier from kernel info.

    Always starts with data type, layout, pipeline and scheduler ("unk"
    when missing). Tile sizes, warp config (prefix "w") and warp tile
    (prefix "wt") are appended as ``MxNxK`` only when their M component is
    greater than zero.
    """
    segments = [
        info.get(field, "unk")
        for field in ("data_type", "layout", "pipeline", "scheduler")
    ]

    # (info key, key stem of the m/n/k entries, prefix in the identifier)
    dimension_sections = (
        ("tile_sizes", "tile", ""),
        ("warp_config", "warp", "w"),
        ("warp_tile", "warp_tile", "wt"),
    )
    for section, stem, prefix in dimension_sections:
        dims = info.get(section, {})
        if dims.get(f"{stem}_m", 0) > 0:
            m, n, k = (dims[f"{stem}_{axis}"] for axis in "mnk")
            segments.append(f"{prefix}{m}x{n}x{k}")

    return "_".join(segments)
def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]:
    """Run one kernel executable and return its parsed JSON output.

    Each entry of ``params`` is passed as ``-key=value`` and
    ``-json_output=true`` is appended. The executable's stdout is written
    to ``<build_dir>/results/<kernel stem>.json`` and then parsed with
    parse_json_file().

    Returns:
        The parsed result, or None on a non-zero exit code, empty output,
        a timeout (60 seconds) or any other error while running.
    """
    results_dir = self.build_dir / "results"
    results_dir.mkdir(exist_ok=True)

    # One JSON file per kernel, named after the executable.
    json_file = results_dir / f"{kernel_path.stem}.json"

    cmd = [
        str(kernel_path),
        *(f"-{key}={value}" for key, value in params.items()),
        # JSON output flag for clean JSON output
        "-json_output=true",
    ]

    if self.verbose:
        print(f"Running: {' '.join(cmd)}")

    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
        if completed.returncode != 0:
            print(f"Error running {kernel_path.name}: {completed.stderr}")
            return None

        output = completed.stdout.strip()
        if not output:
            print(f"No output from {kernel_path.name}")
            return None

        # Save the raw output, then parse it back from the file.
        with open(json_file, "w") as f:
            f.write(output)
        return self.parse_json_file(json_file)
    except subprocess.TimeoutExpired:
        print(f"Timeout running {kernel_path.name}")
        return None
    except Exception as e:
        print(f"Error running {kernel_path.name}: {e}")
        return None
def parse_json_file(self, json_file: Path) -> Optional[Dict]:
"""Parse JSON data from individual kernel output file"""
try:
with open(json_file, "r") as f:
content = f.read().strip()
# Parse the JSON directly since executables produce clean JSON
data = json.loads(content)
# Return the complete JSON data as-is, just add some convenience fields
result = data.copy()
if "perf_result" in data:
perf = data["perf_result"]
# Add convenience fields for backward compatibility
result["time_ms"] = perf.get("latency(ms)", 0)
result["tflops"] = perf.get("tflops(TFlops)", 0)
result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
return result
except json.JSONDecodeError as e:
if self.verbose:
print(f"Failed to parse JSON from {json_file}: {e}")
return None
except Exception as e:
if self.verbose:
print(f"Error reading JSON file {json_file}: {e}")
return None
def benchmark_problem_size(
self,
kernels: List[Path],
m: int,
n: int,
k: int,
split_k: int = 1,
verify: int = 0,
warmup: int = 50,
repeat: int = 100,
flush_cache: bool = True,
rotating_count: int = 1000,
) -> List[Dict]:
"""Benchmark all kernels for a specific problem size"""
results = []
params = {
"m": m,
"n": n,
"k": k,
"split_k": split_k,
"verify": verify,
"warmup": warmup,
"repeat": repeat,
"flush_cache": str(flush_cache).lower(),
"rotating_count": rotating_count,
}
print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}")
for kernel_path in kernels:
kernel_info = self.extract_kernel_info(kernel_path)
result = self.run_kernel(kernel_path, params)
if result:
# Create new structured result format
structured_result = {
"name": kernel_info["name"], # Add name field for compatibility
"config_id": kernel_info["config_id"],
"problem": result.get("problem", {}),
"perf_result": result.get("perf_result", {}),
"config": {
"data_type": kernel_info["data_type"],
"layout": kernel_info["layout"],
"pipeline": kernel_info["pipeline"],
"scheduler": kernel_info["scheduler"],
"epilogue": kernel_info["epilogue"],
"tile_sizes": kernel_info.get("tile_sizes", {}),
"warp_config": kernel_info.get("warp_config", {}),
"warp_tile": kernel_info.get("warp_tile", {}),
"optimization_flags": kernel_info.get("optimization_flags", {}),
},
"executable": kernel_info["executable"],
# Keep backward compatibility fields
"time_ms": result.get("time_ms", 0),
"tflops": result.get("tflops", 0),
"bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
}
results.append(structured_result)
if self.verbose:
print(
f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
)
return results
def find_best_kernel(
    self, results: List[Dict], metric: str = "tflops"
) -> Optional[Dict]:
    """Pick the best result for a metric.

    "tflops" and "bandwidth_gb_s" select the largest value (missing values
    count as 0); "time_ms" selects the smallest (missing values count as
    infinity).

    Returns:
        The winning result dict, or None when ``results`` is empty.

    Raises:
        ValueError: for any other metric name (only when results exist).
    """
    if not results:
        return None

    if metric == "time_ms":
        return min(results, key=lambda entry: entry.get("time_ms", float("inf")))
    if metric in ("tflops", "bandwidth_gb_s"):
        return max(results, key=lambda entry: entry.get(metric, 0))
    raise ValueError(f"Unknown metric: {metric}")
def benchmark_sweep(
self,
problem_sizes: List[Tuple[int, int, int]],
split_k_values: List[int] = [1],
verify: bool = False,
warmup: int = 50,
repeat: int = 100,
flush_cache: bool = True,
rotating_count: int = 1000,
) -> Dict:
"""Run comprehensive benchmark sweep"""
kernels = self.discover_kernels()
if not kernels:
print("No kernels found!")
return {}
all_results = []
best_kernels = {}
for m, n, k in problem_sizes:
for split_k in split_k_values:
results = self.benchmark_problem_size(
kernels,
m,
n,
k,
split_k,
verify=2 if verify else 0,
warmup=warmup,
repeat=repeat,
flush_cache=flush_cache,
rotating_count=rotating_count,
)
all_results.extend(results)
# Find best kernel for this configuration
best = self.find_best_kernel(results)
if best:
key = f"m{m}_n{n}_k{k}_splitk{split_k}"
best_kernels[key] = best
print(
f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
)
self.results = all_results
return best_kernels
def export_csv(self, filename: str):
"""Export all results to CSV"""
if not self.results:
print("No results to export")
return
# Get all unique keys from results
all_keys = set()
for result in self.results:
all_keys.update(result.keys())
# Sort keys for consistent output
fieldnames = sorted(all_keys)
with open(filename, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(self.results)
print(f"Results exported to {filename}")
def export_best_kernels(self, best_kernels: Dict, filename: str):
    """Write the best kernel per problem configuration to a text file.

    The file starts with two comment lines and a blank line, followed by
    one ``key: name (TFLOPS, GB/s, ms)`` line per entry in key order.
    """
    header_lines = (
        "# Best kernel selections\n",
        "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n",
    )
    with open(filename, "w") as f:
        for header_line in header_lines:
            f.write(header_line)

        for key in sorted(best_kernels):
            kernel = best_kernels[key]
            f.write(
                f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
            )

    print(f"Best kernels exported to {filename}")
def export_json(self, filename: str, best_kernels: Dict = None):
"""Export all results and best kernels to JSON with comprehensive metadata"""
from datetime import datetime
# Calculate comprehensive summary statistics for all metrics
successful_results = [r for r in self.results if r.get("tflops", 0) > 0]
tflops_values = [r.get("tflops", 0) for r in successful_results]
bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
latency_values = [
r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
]
# Performance breakdown by kernel type
pipeline_stats = {}
scheduler_stats = {}
data_type_stats = {}
for result in successful_results:
# Get config info from the new structure
config = result.get("config", {})
# Pipeline statistics
pipeline = config.get("pipeline", "unknown")
if pipeline not in pipeline_stats:
pipeline_stats[pipeline] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
pipeline_stats[pipeline]["count"] += 1
pipeline_stats[pipeline]["best_tflops"] = max(
pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0)
)
# Scheduler statistics
scheduler = config.get("scheduler", "unknown")
if scheduler not in scheduler_stats:
scheduler_stats[scheduler] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
scheduler_stats[scheduler]["count"] += 1
scheduler_stats[scheduler]["best_tflops"] = max(
scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0)
)
# Data type statistics
data_type = config.get("data_type", "unknown")
if data_type not in data_type_stats:
data_type_stats[data_type] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
data_type_stats[data_type]["count"] += 1
data_type_stats[data_type]["best_tflops"] = max(
data_type_stats[data_type]["best_tflops"], result.get("tflops", 0)
)
# Calculate averages for breakdown stats
for stats_dict, field_name in [
(pipeline_stats, "pipeline"),
(scheduler_stats, "scheduler"),
(data_type_stats, "data_type"),
]:
for key in stats_dict:
relevant_results = [
r
for r in successful_results
if r.get("config", {}).get(field_name, "unknown") == key
]
if relevant_results:
stats_dict[key]["avg_tflops"] = sum(
r.get("tflops", 0) for r in relevant_results
) / len(relevant_results)
output_data = {
"benchmark_metadata": {
"timestamp": datetime.now().isoformat(),
"total_kernels_tested": len(self.results),
"unique_kernels": len(
set(r.get("name", "unknown") for r in self.results)
),
"successful_runs": len(successful_results),
"failed_runs": len(self.results) - len(successful_results),
},
"performance_summary": {
"tflops_stats": {
"best": max(tflops_values, default=0),
"average": sum(tflops_values) / len(tflops_values)
if tflops_values
else 0,
"min": min(tflops_values, default=0),
"median": sorted(tflops_values)[len(tflops_values) // 2]
if tflops_values
else 0,
},
"bandwidth_stats": {
"best_gb_s": max(bandwidth_values, default=0),
"average_gb_s": sum(bandwidth_values) / len(bandwidth_values)
if bandwidth_values
else 0,
"min_gb_s": min(bandwidth_values, default=0),
"median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2]
if bandwidth_values
else 0,
},
"latency_stats": {
"best_ms": min(latency_values, default=0),
"average_ms": sum(latency_values) / len(latency_values)
if latency_values
else 0,
"max_ms": max(latency_values, default=0),
"median_ms": sorted(latency_values)[len(latency_values) // 2]
if latency_values
else 0,
},
"kernel_type_breakdown": {
"by_pipeline": pipeline_stats,
"by_scheduler": scheduler_stats,
"by_data_type": data_type_stats,
},
"total_problem_configurations": len(best_kernels)
if best_kernels
else 0,
},
"kernel_results": self.results,
"best_kernels_by_problem": best_kernels or {},
}
with open(filename, "w") as f:
json.dump(output_data, f, indent=2)
print(f"JSON results exported to {filename}")
print(f" - Total kernels: {len(self.results)}")
print(f" - Successful runs: {len(successful_results)}")
print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
def main():
    """Command-line entry point for the GEMM Multi D benchmark sweep.

    Parses the arguments, runs the sweep over all problem sizes and split-K
    values, then writes the CSV and best-kernel files (and the JSON file
    when --json is given).

    Returns:
        0 on success, 1 when a problem size cannot be parsed as "M,N,K".
    """
    parser = argparse.ArgumentParser(
        description="GEMM Multi D Kernel Benchmarking Tool"
    )
    parser.add_argument(
        "build_dir", help="Build directory containing kernel executables"
    )
    parser.add_argument(
        "--problem-sizes",
        nargs="+",
        default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
        help="Problem sizes as M,N,K tuples",
    )
    parser.add_argument(
        "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test"
    )
    parser.add_argument("--verify", action="store_true", help="Enable verification")
    parser.add_argument(
        "--csv",
        default="gemm_multi_d_benchmark_results.csv",
        help="CSV output filename",
    )
    parser.add_argument(
        "--best", default="best_kernels.txt", help="Best kernels output filename"
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--warmup",
        type=int,
        default=50,
        help="Number of warmup iterations (default: 50)",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=100,
        help="Number of benchmark iterations (default: 100)",
    )
    # Cache flushing is on by default. A store_true flag whose default is
    # already True can never be switched off, so --no-flush-cache is offered
    # as the way to disable it; --flush-cache is kept for compatibility.
    parser.add_argument(
        "--flush-cache",
        dest="flush_cache",
        action="store_true",
        default=True,
        help="Enable cache flushing (default: True)",
    )
    parser.add_argument(
        "--no-flush-cache",
        dest="flush_cache",
        action="store_false",
        help="Disable cache flushing",
    )
    parser.add_argument(
        "--rotating-count",
        type=int,
        default=1000,
        help="Number of iterations to rotate cache (default: 1000)",
    )
    parser.add_argument("--json", help="JSON output filename (optional)")

    args = parser.parse_args()

    # Parse problem sizes; each must be exactly three comma-separated integers.
    problem_sizes = []
    for size_str in args.problem_sizes:
        try:
            m, n, k = map(int, size_str.split(","))
            problem_sizes.append((m, n, k))
        except ValueError:
            print(f"Invalid problem size: {size_str}")
            return 1

    # Create benchmark instance
    benchmark = GemmMultiDBenchmark(args.build_dir, verbose=args.verbose)

    # Run benchmark sweep
    print("Starting GEMM Multi D kernel benchmark sweep...")
    start_time = time.time()

    best_kernels = benchmark.benchmark_sweep(
        problem_sizes=problem_sizes,
        split_k_values=args.split_k,
        verify=args.verify,
        warmup=args.warmup,
        repeat=args.repeat,
        flush_cache=args.flush_cache,
        rotating_count=args.rotating_count,
    )

    elapsed_time = time.time() - start_time
    print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")

    # Export results
    benchmark.export_csv(args.csv)
    benchmark.export_best_kernels(best_kernels, args.best)

    # Export JSON if requested
    if args.json:
        benchmark.export_json(args.json, best_kernels)

    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,170 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_multi_d_profiler.hpp"
#include "gemm_multi_d_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_multi_d_common.hpp
// Create the argument parser for the single-kernel GEMM Multi-D benchmark and parse argv.
//
// Every option is registered with a string default and a help text. The help texts are kept
// in sync with the defaults actually registered here (flush_cache defaults to "true" and
// rotating_count to "1000").
//
// Returns a tuple (parse_ok, parser), where parse_ok is the value returned by
// ck_tile::ArgParser::parse().
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
        .insert("n", "4096", "The value for n dimension. Default is 4096.")
        .insert("k", "2048", "The value for k dimension. Default is 2048.")
        .insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
        .insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
        .insert("stride_ds", "0", "The stride value for tensor Ds . Default is 0.")
        .insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
        .insert("split_k", "1", "The split value for k dimension. Default is 1.")
        .insert("verify",
                "1",
                "The type of validation. Set to 0 for no validation or 1 for validation on CPU. "
                "Default is 1, validation on CPU, as validation on GPU is not supported.")
        .insert("log",
                "false",
                "Whether output kernel instance information or not. Possible values are true or "
                "false. Default is false")
        .insert(
            "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
        .insert(
            "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
        .insert("timer",
                "true",
                "Whether if the timer is gpu timer or not. Possible values are false or true. "
                "Default is true.")
        .insert("init",
                "0",
                "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
                "for constant(1). Default is 0, random.")
        .insert("flush_cache",
                "true",
                "To flush cache, possible values are true or false. "
                "Default is true.")
        .insert(
            "rotating_count", "1000", "number of iterations to rotate the cache. default is 1000.")
        .insert("metric",
                "0",
                "Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
                "tflops, or 2 for bandwidth. Default is 0, latency.")
        .insert("csv_filename",
                "",
                "The filename of benchmark result. Default is empty (no CSV output).")
        .insert("structured_sparsity",
                "false",
                "Whether use sparsity kernel or not. Possible values are true or false. Default is "
                "false")
        .insert("json_output",
                "false",
                "Whether to output results in JSON format only. Possible values are true or false. "
                "Default is "
                "false");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Benchmark the one kernel compiled into this executable.
//
// The problem description and run settings are assembled from the parsed command line, the
// kernel is timed through the GemmMultiDProfiler singleton, and the best (only) instance is
// reported according to the "metric" option. Exceptions derived from std::exception that are
// thrown while benchmarking are reported on stderr and not propagated.
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
    // Data type names, taken from the types defined by the generated kernel header
    const std::string a_type_name   = DataTypeTraits<ADataType>::name;
    const std::string b_type_name   = DataTypeTraits<BDataType>::name;
    const std::string acc_type_name = DataTypeTraits<AccDataType>::name;
    const std::string c_type_name   = DataTypeTraits<CDataType>::name;
    const std::string d0_type_name  = DataTypeTraits<D0DataType>::name;
    const std::string d1_type_name  = DataTypeTraits<D1DataType>::name;

    // Layout names, taken from the layout tag types
    const std::string a_layout_name  = ALayout::name;
    const std::string b_layout_name  = BLayout::name;
    const std::string c_layout_name  = CLayout::name;
    const std::string d0_layout_name = D0Layout::name;
    const std::string d1_layout_name = D1Layout::name;

    // Problem description; the single "stride_ds" option feeds both D0 and D1 strides
    GemmMultiDProblem problem{arg_parser.get_int("split_k"),
                              arg_parser.get_int("m"),
                              arg_parser.get_int("n"),
                              arg_parser.get_int("k"),
                              arg_parser.get_int("stride_a"),
                              arg_parser.get_int("stride_b"),
                              arg_parser.get_int("stride_ds"),
                              arg_parser.get_int("stride_ds"),
                              arg_parser.get_int("stride_c"),
                              a_type_name,
                              b_type_name,
                              d0_type_name,
                              d1_type_name,
                              acc_type_name,
                              c_type_name,
                              a_layout_name,
                              b_layout_name,
                              d0_layout_name,
                              d1_layout_name,
                              c_layout_name};

    // Run settings
    Setting run_setting{arg_parser.get_int("warmup"),
                        arg_parser.get_int("repeat"),
                        arg_parser.get_bool("timer"),
                        arg_parser.get_int("verify"),
                        arg_parser.get_int("init"),
                        arg_parser.get_bool("log"),
                        arg_parser.get_str("csv_filename"),
                        arg_parser.get_bool("flush_cache"),
                        arg_parser.get_int("rotating_count"),
                        arg_parser.get_bool("json_output")};

    auto& profiler = GemmMultiDProfiler::instance(run_setting);

    try
    {
        // Thin wrapper forwarding to the kernel selected at compile time
        auto launch_selected =
            [](const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& host_args,
               const ck_tile::stream_config& stream_cfg) {
                return SelectedKernel::launch(host_args, stream_cfg);
            };

        profiler.benchmark(problem, launch_selected);

        // Report the result according to the requested metric
        const auto metric = static_cast<Metric>(arg_parser.get_int("metric"));
        profiler.select_best_instance(metric);
    }
    catch(const std::exception& err)
    {
        std::cerr << "Benchmark failed: " << err.what() << std::endl;
    }
}
// Program entry point: parse the command line, then run the single-kernel benchmark.
// Returns EXIT_FAILURE when argument parsing fails or a std::exception escapes, 0 otherwise.
int main(int argc, char* argv[])
{
    try
    {
        auto [parsed_ok, arg_parser] = create_args(argc, argv);
        if(parsed_ok)
        {
            benchmark_single(arg_parser);
            return 0;
        }
        return EXIT_FAILURE;
    }
    catch(const std::exception& err)
    {
        std::cerr << "Error: " << err.what() << "\n";
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// TODO: This can be moved to commons.
// DataTypeTraits maps an element type to a short name string exposed as the static member
// `name`. The primary template is declared but not defined, so using a type that has no
// specialization below is a compile-time error.
template <typename T>
struct DataTypeTraits;
// float -> "fp32"
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
// double -> "fp64"
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
// ck_tile::half_t -> "fp16"
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// ck_tile::bf16_t -> "bf16"
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
// ck_tile::fp8_t -> "fp8"
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
// ck_tile::bf8_t -> "bf8"
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
// ck_tile::int8_t -> "int8"
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
// ck_tile::int32_t -> "int32"
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
// ck_tile::pk_int4_t -> "pk_int4_t" (note: unlike the other names, this one keeps the "_t" suffix)
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Kernel traits used by the dispatcher. Each member carries its default value directly;
// a default-constructed object therefore describes a compv3 / intrawave / cshuffle kernel
// with no padding and no persistence.
struct KernelTraits
{
    std::string pipeline{"compv3"};     // compv3, compv4, mem
    std::string scheduler{"intrawave"}; // intrawave, interwave
    std::string epilogue{"cshuffle"};   // cshuffle, default
    bool pad_m{false};
    bool pad_n{false};
    bool pad_k{false};
    bool persistent{false};

    // User-provided default constructor; the member initializers above supply the defaults.
    KernelTraits() {}
};

View File

@@ -0,0 +1,330 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Load and return the ``GemmKernelBuilder`` class.

    The class is read from ``gemm_instance_builder.py`` located in the parent
    directory of this file. The module is loaded by file path, so the parent
    directory does not need to be on ``sys.path``.

    Returns:
        The ``GemmKernelBuilder`` attribute of the loaded module.

    Raises:
        ImportError: if no import spec/loader can be created for the file.
        Other exceptions raised while executing the module (for example a
        missing file) propagate unchanged.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(current_dir)
    module_path = os.path.join(parent_dir, "gemm_instance_builder.py")

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location("gemm_instance_builder", module_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location may return None; fail with a clear message
        # instead of an AttributeError on the lines below.
        raise ImportError(f"Cannot load gemm_instance_builder from {module_path}")

    gemm_builder_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(gemm_builder_module)

    return gemm_builder_module.GemmKernelBuilder


# Base class for GemmMultiDKernelBuilder, resolved once at import time.
GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmMultiDKernelBuilder(GemmKernelBuilder):
    """Kernel builder for GEMM Multi D.

    Extends ``GemmKernelBuilder`` with the name of the elementwise function
    and with parallel generation of one header per kernel.
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json=None,
    ):
        """Forward the common arguments to the base class and keep ``elementwise_function``."""
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )
        self.elementwise_function = elementwise_function

    def _generate_all_individual(self, num_workers=None):
        """Generate one kernel file per (tile config, trait combination) in parallel.

        Args:
            num_workers: size of the process pool; when ``None`` the CPU count
                capped at 8 is used.

        A failure of a single kernel is printed and skipped. The successfully
        generated kernels are sorted by name and handed to
        ``_generate_cmake_individual_targets``.
        """
        if num_workers is None:
            # Limit to avoid memory issues
            num_workers = min(multiprocessing.cpu_count(), 8)

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # One work item per (tile config, trait combo); each item carries everything
        # the worker needs to rebuild a builder of its own.
        work_items = [
            (
                tile_cfg,
                traits,
                self.kernel_name_prefix,
                self.working_path,
                self.gpu_target,
                self.datatype,
                self.layout,
                self.elementwise_function,
                self.config_json,
            )
            for tile_cfg in tile_configs
            for traits in trait_combos
        ]
        total = len(work_items)

        print(
            f"Generating {total} individual kernel files using {num_workers} workers..."
        )
        print(f"  Tile configs: {len(tile_configs)}")
        print(f"  Trait combinations: {len(trait_combos)}")
        print(f"  Total kernels: {total}")

        # Show the first work item for debugging
        if work_items:
            print("  First work item example:")
            first_tile_cfg, first_traits = work_items[0][:2]
            print(f"    Tile config: {first_tile_cfg}")
            print(f"    Trait combo: {first_traits[:3]}")  # Show first 3 traits

        kernel_list = []
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            print(f"  Submitting {total} tasks to executor...")
            pending = {}
            for work_item in work_items:
                pending[
                    executor.submit(_generate_single_kernel_individual, work_item)
                ] = work_item
            print("  All tasks submitted, waiting for completion...")

            # Collect results as they finish, reporting progress every 100 kernels
            for done_count, future in enumerate(
                concurrent.futures.as_completed(pending), start=1
            ):
                if done_count % 100 == 0 or done_count == total:
                    print(f"  Progress: {done_count}/{total} kernels generated")
                try:
                    outcome = future.result()
                except Exception as exc:
                    print(f"Kernel generation failed for {pending[future]}: {exc}")
                else:
                    if outcome:
                        kernel_list.append(outcome)

        # Sort by kernel name for consistent ordering
        kernel_list.sort(key=lambda entry: entry[0])

        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)

        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker function to generate a single individual kernel file.

    Args:
        work_item: tuple of (tile_config, trait_combo, kernel_name_prefix,
            working_path, gpu_target, datatype, layout, elementwise_function,
            config_json), as assembled by
            ``GemmMultiDKernelBuilder._generate_all_individual``.

    Returns:
        ``(kernel_name, trait_combo, tile_config)`` on success, or ``None`` when
        generating or writing the kernel raised an exception (the error is
        printed).
    """
    from pathlib import Path

    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    ) = work_item

    # Create a temporary builder instance for this worker
    builder = GemmMultiDKernelBuilder(
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    )

    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Strip the "<kernel_name_prefix>_" prefix from the kernel name for the
        # filename. The same string is used for the check and for the slice so
        # the two can never disagree.
        name_prefix = f"{kernel_name_prefix}_"
        simplified_name = kernel_name
        if simplified_name.startswith(name_prefix):
            simplified_name = simplified_name[len(name_prefix) :]

        # Write individual header file. Path() accepts both str and Path, so a
        # plain string working_path no longer breaks the "/" join.
        header_file = Path(working_path) / f"gemm_multi_d_single_{simplified_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)

        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """Command-line entry point of the GEMM Multi D kernel instance builder.

    Parses and validates the arguments, maps the elementwise function name to
    its kernel-side name, builds a ``GemmMultiDKernelBuilder`` and runs exactly
    one mode, checked in this order: ``--list_kernels``, ``--gen_single``,
    ``--gen_all_individual``. With no mode selected the parser reports an error.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Multi D kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcrr", "rrrr", "ccrr", "crrr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--elementwise_function",
        required=True,
        help="Specify what element wise function for D, e.g. mul, add, passthrough",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # Sanity checks on datatype and layout
    assert args.datatype in ["fp16"], (
        f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
    )

    layout_parts = args.layout.lower()
    assert len(layout_parts) == 4, (
        f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
    )
    assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
        f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
    )
    assert layout_parts[2] == "r" and layout_parts[3] == "r", (
        f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
    )

    # Command-line name of the elementwise function -> name used for the kernel
    function_names = {
        "mul": "MultiDMultiply",
        "add": "MultiDAdd",
        "passthrough": "PassThrough",
    }
    elementwise_function = args.elementwise_function.lower()
    if elementwise_function not in function_names:
        raise ValueError(
            f"Invalid elementwise function: {elementwise_function}. "
            f"Valid options are: {', '.join(function_names)}"
        )
    args.elementwise_function = function_names[elementwise_function]

    # Create builder
    kernel_name_prefix = "gemm_multi_d"
    builder = GemmMultiDKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.elementwise_function,
        args.config_json,
    )

    if args.list_kernels:
        builder._list_kernels()
        return

    if args.gen_single:
        # Generate a single kernel file
        if not (args.kernel_name and args.tile_config and args.trait_combo):
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Tile config string: "<tile>_<warp>_<warp_tile>", each group "MxNxK"
        groups = args.tile_config.split("_")
        tile_dims, warp_dims, warp_tile_dims = (
            groups[idx].split("x") for idx in range(3)
        )
        tile_config = {
            f"{group_name}_{axis}": int(dims[pos])
            for group_name, dims in (
                ("tile", tile_dims),
                ("warp", warp_dims),
                ("warp_tile", warp_tile_dims),
            )
            for pos, axis in enumerate("mnk")
        }

        # Trait combo string, in order: pipeline, epilogue, scheduler,
        # pad_m, pad_n, pad_k, persistent (all kept as strings)
        trait_parts = args.trait_combo.split("_")
        trait_combo = tuple(trait_parts[idx] for idx in range(7))

        # Generate the kernel.
        # NOTE(review): the value returned here is not used; whether the header
        # file is written by the builder itself cannot be seen from this file.
        builder._generate_kernel_instance(
            tile_config,
            trait_combo,
        )
        return

    if args.gen_all_individual:
        # Generate all individual kernel files
        builder._generate_all_individual(args.num_workers)
        return

    parser.error(
        "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,307 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gemm_multi_d_benchmark.hpp"
// Singleton profiler for GEMM Multi-D kernels with two D tensors (D0 and D1).
//
// benchmark() allocates and fills the host/device tensors, optionally computes a host
// reference, runs every callable and records the instances that pass verification.
// select_best_instance() then prints the best recorded instance for a metric and, when a CSV
// file name is configured, appends it to that file.
class GemmMultiDProfiler
{
    public:
    // Returns the process-wide profiler. The function-local static is constructed from
    // `setting` on the first call only; later calls return the same object and their
    // `setting` argument is not used.
    static GemmMultiDProfiler& instance(Setting setting)
    {
        static GemmMultiDProfiler instance{setting};
        return instance;
    }

    // Overload for single kernel benchmarking: wraps `kernel_func` (which returns the measured
    // time) into a callable that also reports KERNEL_NAME, then forwards to the vector
    // overload.
    void benchmark(GemmMultiDProblem& gemm_multi_d_problem,
                   std::function<float(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>&,
                                       const ck_tile::stream_config&)> kernel_func)
    {
        // Create a vector with a single callable that returns both name and time
        std::vector<std::function<std::tuple<std::string, float>(
            ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>
            callables;

        callables.push_back([kernel_func](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
                                          const ck_tile::stream_config& stream) {
            float time = kernel_func(args, stream);
            return std::make_tuple(std::string(KERNEL_NAME), time);
        });

        benchmark(gemm_multi_d_problem, callables);
    }

    // Runs every callable on the given problem.
    //
    // The stride fields of `gemm_multi_d_problem` are overwritten with the values returned by
    // ck_tile::get_default_stride (presumably a stride of 0 is replaced by the packed default
    // for the layout -- confirm against ck_tile). A and B are filled uniformly in [-5, 5], D0
    // and D1 in [-1, 1]; C is zeroed on host and device.
    void benchmark(
        GemmMultiDProblem& gemm_multi_d_problem,
        std::vector<std::function<std::tuple<std::string, float>(
            ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>&
            callables)
    {
        const ALayout layout_a   = ALayout{};
        const BLayout layout_b   = BLayout{};
        const D0Layout layout_d0 = D0Layout{};
        const D1Layout layout_d1 = D1Layout{};
        const CLayout layout_c   = CLayout{};

        // Resolve the strides: A is m x k, B is k x n, D0/D1/C are m x n
        gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                                                     gemm_multi_d_problem.k_,
                                                                     gemm_multi_d_problem.stride_a_,
                                                                     is_row_major(layout_a));
        gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_,
                                                                     gemm_multi_d_problem.n_,
                                                                     gemm_multi_d_problem.stride_b_,
                                                                     is_row_major(layout_b));
        gemm_multi_d_problem.stride_d0_ =
            ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                        gemm_multi_d_problem.n_,
                                        gemm_multi_d_problem.stride_d0_,
                                        is_row_major(layout_d0));
        gemm_multi_d_problem.stride_d1_ =
            ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                        gemm_multi_d_problem.n_,
                                        gemm_multi_d_problem.stride_d1_,
                                        is_row_major(layout_d1));
        gemm_multi_d_problem.stride_c_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                                                     gemm_multi_d_problem.n_,
                                                                     gemm_multi_d_problem.stride_c_,
                                                                     is_row_major(layout_c));

        // Host tensors
        ck_tile::HostTensor<ADataType> a_m_k(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.k_,
                                            gemm_multi_d_problem.stride_a_,
                                            is_row_major(layout_a)));
        ck_tile::HostTensor<BDataType> b_k_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_b_,
                                            is_row_major(layout_b)));
        ck_tile::HostTensor<D0DataType> d0_m_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_d0_,
                                            is_row_major(layout_d0)));
        ck_tile::HostTensor<D1DataType> d1_m_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_d1_,
                                            is_row_major(layout_d1)));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_c_,
                                            is_row_major(layout_c)));

        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
        ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
        ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n);

        // Device buffers, sized from the host tensors
        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
        b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
        d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
        d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());

        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        // Exactly two D tensors are passed: D0 and D1
        std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
                                                                  d1_m_n_dev_buf.GetDeviceBuffer()};

        std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {
            gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_};

        ck_tile::GemmMultiDHostArgs<DsDataType::size()> gemm_multi_d_args = {
            a_m_k_dev_buf.GetDeviceBuffer(),
            b_k_n_dev_buf.GetDeviceBuffer(),
            ds_ptr_buf,
            c_m_n_dev_buf.GetDeviceBuffer(),
            gemm_multi_d_problem.split_k_,
            gemm_multi_d_problem.m_,
            gemm_multi_d_problem.n_,
            gemm_multi_d_problem.k_,
            gemm_multi_d_problem.stride_a_,
            gemm_multi_d_problem.stride_b_,
            stridesDs,
            gemm_multi_d_problem.stride_c_,
        };

        ck_tile::HostTensor<CDataType> c_m_n_host_result(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_c_,
                                            is_row_major(layout_c)));

        // The host reference is computed once, only when verification is enabled
        if(setting_.verify_)
        {
            gemm_multi_d_host_reference(
                setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, c_m_n_host_result);
        }

        for(auto& callable : callables)
        {
            // NOTE(review): only log, warmup and repeat are forwarded to stream_config here;
            // the timer / flush_cache / rotating_count settings are not -- confirm intended.
            auto kernel_run_result =
                callable(gemm_multi_d_args,
                         ck_tile::stream_config{
                             nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
            process_result(gemm_multi_d_problem,
                           c_m_n_dev_buf,
                           c_m_n_host_result,
                           c_m_n_dev_result,
                           kernel_run_result);
        }
    }

    // Converts one (name, average time) result into a KernelInstance with latency, TFLOPs and
    // bandwidth, verifies the device output against the host reference when verification is
    // enabled, and records the instance only if it verified. The C buffers are zeroed again
    // afterwards so the next kernel starts from a clean output.
    void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
                        ck_tile::DeviceMem& c_m_n_dev_buf,
                        ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                        ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                        const std::tuple<std::string, float>& kernel_run_result)
    {
        auto [name, avg_time] = kernel_run_result;

        KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};

        // compute performance metric
        std::size_t flop = std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
                           gemm_multi_d_problem.k_;
        std::size_t num_byte =
            sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
            sizeof(BDataType) * gemm_multi_d_problem.n_ * gemm_multi_d_problem.k_ +
            sizeof(CDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;

        // Dth Dimension Updates: every D tensor adds m*n elements of traffic and one
        // elementwise operation per output element. The operation count must not be scaled by
        // the element size in bytes.
        ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) {
            num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
                        gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
            flop += std::size_t(1) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
        });

        // update
        kernel_instance.perf_result_.latency_   = avg_time;
        kernel_instance.perf_result_.tflops_    = static_cast<float>(flop) / 1.E9 / avg_time;
        kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;

        if(setting_.log_ > 0 && !setting_.json_output_)
        {
            std::cout << kernel_instance << std::endl;
        }

        // verify result
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
        bool verified_correct =
            !setting_.verify_ || compare(name,
                                         gemm_multi_d_problem.k_,
                                         1, // Multi d currently supports only k_batch = 1
                                         c_m_n_dev_result,
                                         c_m_n_host_result);

        if(verified_correct)
        {
            kernel_instances_.emplace_back(kernel_instance);
        }
        else
        {
            std::cout << "Verification failed, skip kernel: " << name << std::endl;
        }

        // clear tensor
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();
    }

    // Picks the best recorded instance for `metric`, prints it (as-is in JSON mode, framed
    // otherwise) and, when a CSV file name is configured, appends one row to "<name>.csv".
    // Throws std::runtime_error when no instance has been recorded.
    KernelInstance select_best_instance(Metric metric)
    {
        if(kernel_instances_.empty())
            throw std::runtime_error("Empty instances");

        // The comparator arguments are swapped (b, a) on purpose, as in the original code.
        auto kernel_instance = *std::max_element(kernel_instances_.begin(),
                                                 kernel_instances_.end(),
                                                 [metric](const auto& a, const auto& b) {
                                                     return PerformanceResult::compare(
                                                         b.perf_result_, a.perf_result_, metric);
                                                 });

        if(setting_.json_output_)
        {
            // Output clean JSON only
            std::cout << kernel_instance << std::endl;
        }
        else
        {
            std::cout << "**********************************" << std::endl;
            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
                      << "Current kernel performance is: " << kernel_instance << std::endl;
            std::cout << "**********************************" << std::endl;
        }

        if(!setting_.csv_filename_.empty())
        {
            std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);

            if(!file.is_open())
            {
                std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
            }
            else
            {
                // In append mode the put position right after opening is not guaranteed to be
                // at end-of-file on every implementation; seek there explicitly so that
                // tellp() == 0 really means "the file is empty".
                file.seekp(0, std::ios::end);
                if(file.tellp() == 0)
                {
                    // Header columns must match, one for one, the values written below.
                    file << "rocm_version,device_name,"
                         << "split_k,m,n,k,stride_a,stride_b,stride_c,"
                         << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
                         << "name,"
                         << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
                }

                const auto& problem = kernel_instance.problem_;
                const auto& name    = kernel_instance.name_;
                const auto& perf    = kernel_instance.perf_result_;

                file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
                     << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
                     << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
                     << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
                     << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
                     << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
                     << "," << name << "," << std::fixed << std::setprecision(4) << perf.latency_
                     << "," << std::fixed << std::setprecision(4) << perf.tflops_ << ","
                     << std::fixed << std::setprecision(4) << perf.bandwidth_ << ","
                     << get_metric_name(metric) << "\n";

                if(!file)
                {
                    std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
                }
            }
        }

        return kernel_instance;
    }

    GemmMultiDProfiler(const GemmMultiDProfiler&)            = delete;
    GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete;

    private:
    ~GemmMultiDProfiler() { kernel_instances_.clear(); }
    GemmMultiDProfiler(Setting setting) : setting_(setting) {}

    Setting setting_;                              // run settings captured at construction
    std::vector<KernelInstance> kernel_instances_; // instances that passed verification
};

View File

@@ -0,0 +1,302 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# User-configurable cache entries for the GEMM Preshuffle benchmarks.
# Datatypes and layouts are semicolon-separated lists.
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
# File name only; it is resolved relative to the configs/ folder next to this file.
# An empty value means "use the default config".
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
# NOTE(review): the use of this option is not visible in this part of the file.
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
# Store the directory path for use in functions
set(GEMM_PRESHUFFLE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# Function to create individual GEMM Preshuffle targets
#
# Creates one benchmark executable for a single kernel configuration together with the build
# rule that generates its instance header, and attaches the executable to the collection
# targets (all / datatype / layout / datatype_layout / pipeline / epilogue / scheduler).
#
# Arguments:
#   datatype     - datatype name, e.g. fp16
#   layout       - layout name, e.g. rcr
#   trait        - underscore-separated trait string; fields 0..2 are pipeline, epilogue, scheduler
#   tile_config  - "<tile>_<warp>_<warp_tile>", each group "MxNxK", e.g. 256x256x32_4x1x1_16x16x16
#   config_json  - path of the JSON config file passed to the instance builder
#
# Reads GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL and GEMM_PRESHUFFLE_SOURCE_DIR from the
# calling scope. The collection targets must already exist.
function(create_individual_gemm_preshuffle_target datatype layout trait tile_config config_json)
    # Use the parent scope GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL variable
    if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
        message(WARNING "Skipping individual GEMM Preshuffle target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
        return()
    endif()

    # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
    # The parsed values are not used further below; the list(GET) calls act as a format check
    # (an out-of-range index is a configure error).
    # First split by underscore to get three groups
    string(REPLACE "_" ";" config_groups "${tile_config}")
    list(GET config_groups 0 tile_dims)       # e.g., 256x256x32
    list(GET config_groups 1 warp_dims)       # e.g., 4x1x1
    list(GET config_groups 2 warp_tile_dims)  # e.g., 16x16x16

    # Parse tile dimensions
    string(REPLACE "x" ";" tile_parts "${tile_dims}")
    list(GET tile_parts 0 tile_m)
    list(GET tile_parts 1 tile_n)
    list(GET tile_parts 2 tile_k)

    # Parse warp dimensions
    string(REPLACE "x" ";" warp_parts "${warp_dims}")
    list(GET warp_parts 0 warp_m)
    list(GET warp_parts 1 warp_n)
    list(GET warp_parts 2 warp_k)

    # Parse warp tile dimensions
    string(REPLACE "x" ";" warp_tile_parts "${warp_tile_dims}")
    list(GET warp_tile_parts 0 warp_tile_m)
    list(GET warp_tile_parts 1 warp_tile_n)
    list(GET warp_tile_parts 2 warp_tile_k)

    set(target_name "benchmark_gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Generate the single instance header for this kernel
    set(instance_header "${working_path}/gemm_preshuffle_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

    # Add custom command to generate the header file at build time.
    # VERBATIM keeps argument escaping platform-independent; this matters here because the GPU
    # target list is passed as one argument that may contain ';'.
    add_custom_command(
        OUTPUT "${instance_header}"
        COMMAND ${Python3_EXECUTABLE} "${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py"
                --working_path "${working_path}"
                --datatype "${datatype}"
                --layout "${layout}"
                --config_json "${config_json}"
                --gen_single
                --kernel_name "gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
                --gpu_target "${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}"
        DEPENDS "${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py" "${config_json}"
        COMMENT "Generating ${instance_header}"
        VERBATIM
    )

    # Create the executable
    add_executable(${target_name}
        EXCLUDE_FROM_ALL
        "${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp"
        "${instance_header}"
    )

    # Set GPU architectures
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL})

    # Set compile definitions
    target_compile_definitions(${target_name} PRIVATE
        GEMM_PRESHUFFLE_SINGLE_INSTANCE_HPP="${instance_header}"
    )

    # Include directories
    target_include_directories(${target_name} PRIVATE
        "${GEMM_PRESHUFFLE_SOURCE_DIR}"
        "${working_path}"
    )

    # Compile options; the generated header is force-included into the benchmark source
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )

    # Add to collection targets
    add_dependencies(benchmark_gemm_preshuffle_all ${target_name})
    add_dependencies(benchmark_gemm_preshuffle_${datatype} ${target_name})
    add_dependencies(benchmark_gemm_preshuffle_${layout} ${target_name})
    add_dependencies(benchmark_gemm_preshuffle_${datatype}_${layout} ${target_name})

    # Add to trait-specific targets
    string(REPLACE "_" ";" trait_parts "${trait}")
    list(GET trait_parts 0 pipeline)
    list(GET trait_parts 1 epilogue)
    list(GET trait_parts 2 scheduler)

    add_dependencies(benchmark_gemm_preshuffle_${pipeline}_pipeline ${target_name})
    add_dependencies(benchmark_gemm_preshuffle_${epilogue}_epilogue ${target_name})
    add_dependencies(benchmark_gemm_preshuffle_${scheduler}_scheduler ${target_name})
endfunction()
# Function to build individual GEMM Preshuffle targets
#
# For one (datatype, layout) pair: selects the JSON config, runs the instance builder with
# --list_kernels at configure time, then creates one benchmark target per line of the
# generated kernel list.
#
# Arguments:
#   datatype - datatype name, e.g. fp16
#   layout   - layout name, e.g. rcr
#
# Fails the configure step (FATAL_ERROR) when the config file is missing, the listing command
# fails, or the kernel count / kernel list files are not produced.
function(build_individual_gemm_preshuffle_targets datatype layout)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Choose config file
    # Priority order:
    # 1. Environment variable GEMM_PRESHUFFLE_CONFIG_FILE
    # 2. CMake variable GEMM_PRESHUFFLE_CONFIG_FILE
    # 3. configs/default_config.json (the same default for every layout)

    # Check environment variable first
    if(DEFINED ENV{GEMM_PRESHUFFLE_CONFIG_FILE} AND NOT "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
        message(VERBOSE "  Using config from environment variable: ${config_filename}")
    elseif(NOT "${GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
        # Use CMake variable if set
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_PRESHUFFLE_CONFIG_FILE}")
        message(VERBOSE "  Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}")
    else()
        # Use default config for all layouts
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
        message(VERBOSE "  Using default config for layout ${layout}")
    endif()

    # Check if config file exists
    if(NOT EXISTS "${json_blob}")
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()

    # Determine number of workers for parallel generation.
    # The value is only reported below; it is not passed to the listing command.
    if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
        set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    else()
        # Use processor count but limit to avoid memory issues
        cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
        math(EXPR num_workers "${num_cores}")
        if(num_workers GREATER 8)
            set(num_workers 8)
        endif()
    endif()

    # Generate individual kernel files using parallel version
    message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
    message(VERBOSE "  Working path: ${working_path}")
    message(VERBOSE "  Config file: ${json_blob}")
    message(VERBOSE "  Python executable: ${Python3_EXECUTABLE}")
    message(VERBOSE "  Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py")

    # Create working directory first
    file(MAKE_DIRECTORY "${working_path}")

    # First, just list the kernels (fast operation)
    message(VERBOSE "  Listing kernel configurations...")
    message(VERBOSE "  GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")

    # The GPU target list is quoted so that several targets stay ONE argument
    # ("gfx90a;gfx942"), the same form the per-kernel custom command passes. Unquoted, every
    # extra target would become a separate positional argument of the script.
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u "${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py"
                --working_path "${working_path}"
                --gpu_target "${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}"
                --datatype "${datatype}"
                --layout "${layout}"
                --config_json "${json_blob}"
                --list_kernels
        WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}"
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )

    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
    endif()

    # Read kernel count
    if(EXISTS "${working_path}/gemm_preshuffle_kernel_count.txt")
        file(READ "${working_path}/gemm_preshuffle_kernel_count.txt" kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(VERBOSE "  Found ${kernel_count} kernel configurations")
    else()
        message(FATAL_ERROR "Kernel count file not found")
    endif()

    # Read kernel list and create targets
    if(EXISTS "${working_path}/gemm_preshuffle_kernel_list.txt")
        file(STRINGS "${working_path}/gemm_preshuffle_kernel_list.txt" kernel_lines)
        foreach(line IN LISTS kernel_lines)
            # Parse line: kernel_name|tile_config|trait_combo
            string(REPLACE "|" ";" parts "${line}")
            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)

            # Create individual target
            create_individual_gemm_preshuffle_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
        endforeach()
    else()
        message(FATAL_ERROR "Kernel list file not found")
    endif()
endfunction()
# Main build logic - only individual (one executable per kernel) builds are supported.
#
# Steps:
#   1. Intersect SUPPORTED_GPU_TARGETS with the architectures listed in DESIRED_TARGETS.
#   2. If nothing is left, warn and create no targets at all.
#   3. Otherwise register job pools, optionally enable ccache, create the umbrella
#      (collection) targets and finally call build_individual_gemm_preshuffle_targets()
#      once per datatype/layout pair.
message(VERBOSE "=== Starting Tile Engine GEMM Preshuffle Configuration ===")
message(VERBOSE "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}")
message(VERBOSE "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Filter GPU targets to only gfx90a, gfx942, and gfx950
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
  if(target IN_LIST DESIRED_TARGETS)
    list(APPEND GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL ${target})
    message(VERBOSE " Adding GPU target: ${target}")
  endif()
endforeach()

if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
  # Skip build if no matching targets found.
  # The message names the Preshuffle component explicitly so that it cannot be
  # confused with the warning emitted by the plain GEMM tile-engine listfile.
  message(WARNING "Skipping Tile Engine GEMM Preshuffle build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
  message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")

  # Enable parallel compilation optimizations
  # Set up job pools for better parallel compilation control
  set_property(GLOBAL PROPERTY JOB_POOLS
    compile_heavy=4 # Limit heavy compilations to prevent OOM
    compile_normal=16 # Allow more parallel normal compilations
  )

  # Enable compiler cache if available and explicitly requested
  # Disabled by default due to permission issues in CI environments
  if(ENABLE_CCACHE_GEMM_PRESHUFFLE)
    find_program(CCACHE_PROGRAM ccache)
    if(CCACHE_PROGRAM)
      set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
      message(VERBOSE "Using ccache for faster compilation")
    else()
      message(WARNING "ccache requested but not found")
    endif()
  else()
    message(VERBOSE "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)")
  endif()

  # Create master collection targets
  add_custom_target(benchmark_gemm_preshuffle_all)

  # Create datatype collection targets
  foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
    add_custom_target(benchmark_gemm_preshuffle_${dt})
  endforeach()

  # Create layout collection targets
  foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
    add_custom_target(benchmark_gemm_preshuffle_${l})
  endforeach()

  # Create combined collection targets
  foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
    foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
      add_custom_target(benchmark_gemm_preshuffle_${dt}_${l})
    endforeach()
  endforeach()

  # Create trait-based collection targets
  # These are common trait components used across all GEMM Preshuffle kernels
  set(GEMM_PRESHUFFLE_PIPELINES "preshufflev2")
  set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle")
  set(GEMM_PRESHUFFLE_SCHEDULERS "default")
  foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES)
    add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline)
  endforeach()
  foreach(epilogue IN LISTS GEMM_PRESHUFFLE_EPILOGUES)
    add_custom_target(benchmark_gemm_preshuffle_${epilogue}_epilogue)
  endforeach()
  foreach(scheduler IN LISTS GEMM_PRESHUFFLE_SCHEDULERS)
    add_custom_target(benchmark_gemm_preshuffle_${scheduler}_scheduler)
  endforeach()

  # Build individual targets for each datatype/layout combination
  foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
    foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
      build_individual_gemm_preshuffle_targets(${dt} ${l})
    endforeach()
  endforeach()
endif()

View File

@@ -0,0 +1,102 @@
{
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"preshufflev2"
]
},
"scheduler": {
"values": [
"default"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
true,
false
]
}
},
"k_block_per_cu": 1,
"permute_n": true
}

View File

@@ -0,0 +1,89 @@
{
"tile_config": {
"tile_m": {
"values": [
64
]
},
"tile_n": {
"values": [
64
]
},
"tile_k": {
"values": [
192
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
32
]
}
},
"trait_config": {
"pipeline": {
"values": [
"preshufflev2"
]
},
"scheduler": {
"values": [
"default"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
true
]
}
},
"k_block_per_cu": 1,
"permute_n": false
}

View File

@@ -0,0 +1,236 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_preshuffle_common.hpp"
//[TODO] Move parts of this File to commons
/// Metric used to rank benchmark results against each other
/// (consumed by get_metric_name and PerformanceResult::compare in this header).
enum class Metric
{
LATENCY = 0, ///< PerformanceResult::compare ranks this with `<` (smaller value wins)
TFLOPS = 1, ///< PerformanceResult::compare ranks this with `>` (larger value wins)
BANDWIDTH = 2 ///< PerformanceResult::compare ranks this with `>` (larger value wins)
};
/// Maps a Metric to its lowercase display name.
///
/// @param m metric to name
/// @return "latency", "tflops" or "bandwidth"
/// @throws std::invalid_argument when @p m is not one of the known enumerators
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
    {
        return "latency";
    }
    if(m == Metric::TFLOPS)
    {
        return "tflops";
    }
    if(m == Metric::BANDWIDTH)
    {
        return "bandwidth";
    }
    // Only reachable for a value that was cast into Metric from outside the enum.
    throw std::invalid_argument("Unsupported metric type");
}
/// Plain aggregate describing the tiling of one kernel instance.
/// NOTE(review): each tuple presumably holds (m, n, k) in that order - confirm against
/// the code that fills this struct; nothing in this header reads the fields.
struct KernelConfig
{
std::tuple<int, int, int> tile_dims; // block-level tile sizes
std::tuple<int, int, int> warp_dims; // warp layout
std::tuple<int, int, int> warp_tile_dims; // per-warp tile sizes
bool permuteN; // presumably mirrors the "permute_n" config entry - TODO confirm
};
/// Description of one GEMM problem (sizes, strides, data types and layouts).
/// Streaming a GemmProblem writes it as a JSON object.
struct GemmProblem
{
    int split_k_;
    int m_;
    int n_;
    int k_;
    int stride_a_;
    int stride_b_;
    int stride_c_;
    std::string dtype_a_;
    std::string dtype_b_;
    std::string dtype_acc_;
    std::string dtype_c_;
    std::string layout_a_;
    std::string layout_b_;
    std::string layout_c_;
    bool structured_sparsity_;

    /// Writes the problem as a JSON object; every member becomes one key, in declaration order.
    friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
    {
        // Integer member: "key":value followed by a comma.
        const auto emit_number = [&os](const char* key, int value) {
            os << " \"" << key << "\":" << value << ",\n";
        };
        // String member: "key":"value" followed by a comma.
        const auto emit_text = [&os](const char* key, const std::string& value) {
            os << " \"" << key << "\":\"" << value << "\",\n";
        };

        os << "{\n";
        emit_number("split_k", problem.split_k_);
        emit_number("m", problem.m_);
        emit_number("n", problem.n_);
        emit_number("k", problem.k_);
        emit_number("stride_a", problem.stride_a_);
        emit_number("stride_b", problem.stride_b_);
        emit_number("stride_c", problem.stride_c_);
        emit_text("dtype_a", problem.dtype_a_);
        emit_text("dtype_b", problem.dtype_b_);
        emit_text("dtype_acc", problem.dtype_acc_);
        emit_text("dtype_c", problem.dtype_c_);
        emit_text("layout_a", problem.layout_a_);
        emit_text("layout_b", problem.layout_b_);
        emit_text("layout_c", problem.layout_c_);

        // Last member: no trailing comma.
        const char* sparsity = problem.structured_sparsity_ ? "true" : "false";
        os << " \"structured_sparsity\":" << sparsity;
        os << "\n";
        os << "}";
        return os;
    }
};
/// Measured performance of one kernel run.
struct PerformanceResult
{
    double latency_;   // printed under the key "latency(ms)"
    double tflops_;    // printed under the key "tflops(TFlops)"
    double bandwidth_; // printed under the key "bandwidth(GB/s)"

    /// Returns true when @p a is better than @p b for metric @p m:
    /// lower latency, or higher tflops / bandwidth.
    /// @throws std::invalid_argument for a metric value outside the enum
    static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
    {
        switch(m)
        {
        case Metric::LATENCY: return a.latency_ < b.latency_;
        case Metric::TFLOPS: return a.tflops_ > b.tflops_;
        case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
        default: throw std::invalid_argument("Unsupported metric type");
        }
    }

    /// Writes the result as a JSON object with two-decimal fixed-point numbers.
    /// The stream's format flags and precision are restored before returning, so
    /// floating-point values the caller prints afterwards keep their own formatting.
    friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
    {
        // std::fixed and std::setprecision are sticky; remember what the caller had.
        const std::ios_base::fmtflags saved_flags = os.flags();
        const std::streamsize saved_precision     = os.precision();

        os << "{\n"
           << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
           << ",\n"
           << " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
           << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
           << "}";

        os.flags(saved_flags);
        os.precision(saved_precision);
        return os;
    }
};
struct KernelInstance
{
std::string name_;
GemmProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << obj.name_ << "\",\n"
<< " \"problem\": " << obj.problem_ << ",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
/// Run-time benchmark settings, stored as a plain aggregate.
/// NOTE(review): nothing in this header reads these fields; the per-field notes below are
/// inferred from the names and must be confirmed against the code that fills the struct.
struct Setting
{
int n_warmup_; // presumably the number of warm-up iterations - TODO confirm
int n_repeat_; // presumably the number of timed iterations - TODO confirm
bool is_gpu_timer_; // presumably selects the GPU timer over a host timer - TODO confirm
int verify_; // presumably the verification mode passed to gemm_host_reference - TODO confirm
int init_method_; // presumably the tensor initialization method - TODO confirm
bool log_; // presumably enables kernel-instance logging - TODO confirm
std::string csv_filename_; // presumably the CSV output file name - TODO confirm
bool flush_cache_; // presumably enables cache flushing between runs - TODO confirm
int rotating_count_; // presumably the rotating-buffer iteration count - TODO confirm
bool json_output_; // presumably switches the output to JSON - TODO confirm
};
/// Reads the first line of /opt/rocm/.info/version.
///
/// @return the first line of that file, or "Unknown" when the file cannot be opened
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    if(!version_file.is_open())
    {
        return "Unknown";
    }

    std::string first_line;
    std::getline(version_file, first_line);
    return first_line;
}
/// Computes the relative and absolute error thresholds used for result verification.
///
/// Two pairs of thresholds are computed - one for the accumulation along K inside a single
/// split, one for the accumulation of the kbatch partial results - and the larger value of
/// each kind is returned.
///
/// @param K                     reduction length of the GEMM
/// @param kbatch                number of K splits
/// @param max_accumulated_value largest value of the reference result
/// @return ck_tile tuple (rtol, atol)
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The smaller of the two input types is used as the compute type.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of K elements accumulated inside one split.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the accumulation inside one split.
    const auto rtol_in_split =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto atol_in_split =
        ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
            max_accumulated_value / kbatch, k_per_split);

    // Thresholds for the error due to split_k accumulation.
    const auto rtol_across_splits =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_across_splits =
        ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
                                                                         kbatch);

    // Use the higher threshold of each kind.
    const auto rtol = std::max(rtol_in_split, rtol_across_splits);
    const auto atol = std::max(atol_in_split, atol_across_splits);
    return ck_tile::make_tuple(rtol, atol);
}
/// @brief Function to compare the results of the device and host computations
///
/// Derives the tolerances from calculate_rtol_atol using the largest reference value,
/// checks the device result with ck_tile::check_err and prints the thresholds and the
/// verdict to std::cout.
///
/// Declared `inline` because this header defines the function; without it, including the
/// header from more than one translation unit of a program violates the one-definition rule.
///
/// @param instanceName     kernel name used in the printed report
/// @param K                reduction length of the GEMM
/// @param kbatch           number of K splits
/// @param c_m_n_dev_result result computed on the device
/// @param c_m_n_ref        reference result
/// @return true when the device result is within tolerance
inline bool compare(std::string instanceName,
                    ck_tile::index_t K,
                    ck_tile::index_t kbatch,
                    ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                    ck_tile::HostTensor<CDataType>& c_m_n_ref)
{
    // NOTE(review): an empty reference tensor would dereference end() here - callers are
    // assumed to pass a non-empty tensor; confirm.
    const float max_accumulated_value =
        *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
    const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
        K, kbatch, max_accumulated_value);
    bool pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_ref,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));
    std::cout << "For " << instanceName << " Relative error threshold is "
              << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
              << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
    return pass;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
///
/// - verify == 1: runs ck_tile::reference_gemm on the host tensors.
/// - verify == 2: uploads A and B to their device buffers, runs ck_tile::reference_gemm_gpu
///   into a zero-initialized temporary device buffer and copies the result into c_m_n_ref.
/// - any other value: does nothing.
///
/// Declared `inline` because this header defines the function; without it, including the
/// header from more than one translation unit of a program violates the one-definition rule.
///
/// @param verify        verification mode (see above)
/// @param a_m_k         host tensor A
/// @param b_k_n         host tensor B
/// @param c_m_n_ref     host tensor receiving the reference result
/// @param a_m_k_dev_buf device buffer for A (written when verify == 2)
/// @param b_k_n_dev_buf device buffer for B (written when verify == 2)
/// @param M, N, K       problem sizes (used by the GPU reference only)
/// @param stride_A, stride_B, stride_C strides (used by the GPU reference only)
inline void gemm_host_reference(int verify,
                                ck_tile::HostTensor<ADataType>& a_m_k,
                                ck_tile::HostTensor<BDataType>& b_k_n,
                                ck_tile::HostTensor<CDataType>& c_m_n_ref,
                                ck_tile::DeviceMem& a_m_k_dev_buf,
                                ck_tile::DeviceMem& b_k_n_dev_buf,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t stride_A,
                                ck_tile::index_t stride_B,
                                ck_tile::index_t stride_C)
{
    if(verify == 1)
    {
        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_ref);
    }
    else if(verify == 2)
    {
        // Upload the inputs so the GPU reference sees the same data as the host tensors.
        a_m_k_dev_buf.ToDevice(a_m_k.data());
        b_k_n_dev_buf.ToDevice(b_k_n.data());

        // Temporary output buffer, released when this scope ends.
        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
        c_m_n_gpu_buf_ref.SetZero();

        ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
        BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

        c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
    }
}

View File

@@ -0,0 +1,683 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import sys
import json
import subprocess
import argparse
import csv
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional
class GemmPreshuffleBenchmark:
    """Discover, run and summarize the benchmark_gemm_preshuffle* executables of a build tree.

    Attributes:
        build_dir: root of the build tree; executables are looked up in ``build_dir/bin``
            and per-kernel JSON output is written to ``build_dir/results``.
        verbose: when True, extra progress information is printed.
        results: flat list of structured results filled by ``benchmark_sweep``.
    """

    def __init__(self, build_dir: str, verbose: bool = False):
        self.build_dir = Path(build_dir)
        self.verbose = verbose
        self.results = []

    def discover_kernels(self) -> List[Path]:
        """Find all benchmark_gemm_preshuffle* executables in the build directory"""
        bin_dir = self.build_dir / "bin"
        if not bin_dir.exists():
            print(f"Error: Binary directory {bin_dir} does not exist")
            return []

        kernels = list(bin_dir.glob("benchmark_gemm_preshuffle*"))
        if self.verbose:
            print(f"Found {len(kernels)} kernel executables")
            for k in kernels:
                print(f" - {k.name}")
        return kernels

    def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
        """Extract comprehensive kernel information from filename

        Fields that cannot be derived from the name keep the value "unknown".
        """
        name = kernel_path.stem

        # Initialize with basic info
        info = {
            "executable": str(kernel_path),
            "name": name,
            "data_type": "unknown",
            "layout": "unknown",
            "pipeline": "unknown",
            "scheduler": "unknown",
            "epilogue": "unknown",
        }

        # Parse the kernel name pattern:
        # benchmark_gemm_preshuffle_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16
        parts = name.split("_")

        # Positional fields after the "benchmark_gemm_preshuffle_" prefix
        # (0-based index into `parts`). Each field is only read when the name is
        # long enough, so short names leave the remaining fields as "unknown"
        # instead of raising IndexError.
        positional_fields = (
            (3, "data_type"),  # 4th part
            (4, "layout"),  # 5th part
            (5, "pipeline"),  # 6th part
            (6, "epilogue"),  # 7th part
            (7, "scheduler"),  # 8th part
        )
        for index, field in positional_fields:
            if len(parts) > index:
                info[field] = parts[index]

        # Extract detailed configuration from the end of the name
        config_info = self.parse_detailed_config(name)
        info.update(config_info)

        # Generate config ID
        info["config_id"] = self.generate_config_id(info)
        return info

    def parse_detailed_config(self, kernel_name: str) -> Dict:
        """Parse detailed configuration from kernel name

        Returns a dict with the keys "tile_sizes", "warp_config", "warp_tile" and
        "optimization_flags"; values that cannot be found stay at 0 / False.
        """
        config = {
            "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
            "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
            "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
            "optimization_flags": {
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
            },
        }

        # Split by underscore and look for patterns
        parts = kernel_name.split("_")

        # Look for boolean flags (first run of consecutive True/False values)
        bool_sequence = []
        for i, part in enumerate(parts):
            if part in ["True", "False"]:
                bool_sequence.append(part == "True")
                # Continue collecting consecutive boolean values
                j = i + 1
                while j < len(parts) and parts[j] in ["True", "False"]:
                    bool_sequence.append(parts[j] == "True")
                    j += 1
                break

        # Assign boolean flags if we found them
        # Order: pad_m, pad_n, pad_k, persistent (4 flags total)
        if len(bool_sequence) >= 4:
            config["optimization_flags"]["pad_m"] = bool_sequence[0]
            config["optimization_flags"]["pad_n"] = bool_sequence[1]
            config["optimization_flags"]["pad_k"] = bool_sequence[2]
            config["optimization_flags"]["persistent"] = bool_sequence[3]

        # Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16)
        # The pattern is: tile_sizes_warp_config_warp_tile
        dimension_groups = []
        for part in parts:
            if "x" in part and len(part.split("x")) == 3:
                try:
                    dims = [int(x) for x in part.split("x")]
                    if all(d > 0 for d in dims):
                        dimension_groups.append(dims)
                except ValueError:
                    continue

        # Assign dimensions based on order and magnitude
        if len(dimension_groups) >= 3:
            # Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile
            sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)

            # Largest dimensions = tile sizes
            config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
            config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
            config["tile_sizes"]["tile_k"] = sorted_groups[0][2]

            # Smallest dimensions = warp config
            config["warp_config"]["warp_m"] = sorted_groups[2][0]
            config["warp_config"]["warp_n"] = sorted_groups[2][1]
            config["warp_config"]["warp_k"] = sorted_groups[2][2]

            # Middle dimensions = warp tile
            config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0]
            config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1]
            config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2]
        elif len(dimension_groups) == 2:
            # If only 2 groups, assign based on magnitude
            sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)

            # Larger = tile sizes
            config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
            config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
            config["tile_sizes"]["tile_k"] = sorted_groups[0][2]

            # Smaller = warp config
            config["warp_config"]["warp_m"] = sorted_groups[1][0]
            config["warp_config"]["warp_n"] = sorted_groups[1][1]
            config["warp_config"]["warp_k"] = sorted_groups[1][2]
        elif len(dimension_groups) == 1:
            # Only one group - assume it's tile sizes
            config["tile_sizes"]["tile_m"] = dimension_groups[0][0]
            config["tile_sizes"]["tile_n"] = dimension_groups[0][1]
            config["tile_sizes"]["tile_k"] = dimension_groups[0][2]

        return config

    def generate_config_id(self, info: Dict) -> str:
        """Generate a compact config ID from kernel info"""
        # Create a compact identifier
        parts = [
            info.get("data_type", "unk"),
            info.get("layout", "unk"),
            info.get("pipeline", "unk"),
            info.get("scheduler", "unk"),
        ]

        # Add tile configuration if available
        tile_sizes = info.get("tile_sizes", {})
        if tile_sizes.get("tile_m", 0) > 0:
            tile_str = (
                f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
            )
            parts.append(tile_str)

        # Add warp config if available
        warp_config = info.get("warp_config", {})
        if warp_config.get("warp_m", 0) > 0:
            warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
            parts.append(warp_str)

        # Add warp tile if available
        warp_tile = info.get("warp_tile", {})
        if warp_tile.get("warp_tile_m", 0) > 0:
            warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
            parts.append(warp_tile_str)

        return "_".join(parts)

    def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]:
        """Run a single kernel with given parameters and save output to individual JSON file

        Returns the parsed JSON output, or None when the executable fails, times out
        (60 s), prints nothing, or prints something that is not valid JSON.
        """
        # Create results directory
        results_dir = self.build_dir / "results"
        results_dir.mkdir(exist_ok=True)

        # Generate unique JSON filename for this kernel
        json_file = results_dir / f"{kernel_path.stem}.json"

        cmd = [str(kernel_path)]

        # Add parameters
        for key, value in params.items():
            cmd.append(f"-{key}={value}")

        # Add JSON output flag for clean JSON output
        cmd.append("-json_output=true")

        if self.verbose:
            print(f"Running: {' '.join(cmd)}")

        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode != 0:
                print(f"Error running {kernel_path.name}: {result.stderr}")
                return None

            # Save raw output to individual JSON file
            output = result.stdout.strip()
            if output:
                with open(json_file, "w") as f:
                    f.write(output)
                # Parse the JSON file
                return self.parse_json_file(json_file)
            else:
                print(f"No output from {kernel_path.name}")
                return None
        except subprocess.TimeoutExpired:
            print(f"Timeout running {kernel_path.name}")
            return None
        except Exception as e:
            # Best effort: one broken executable must not abort the whole sweep.
            print(f"Error running {kernel_path.name}: {e}")
            return None

    def parse_json_file(self, json_file: Path) -> Optional[Dict]:
        """Parse JSON data from individual kernel output file

        Returns the parsed object plus the convenience keys "time_ms", "tflops" and
        "bandwidth_gb_s" (when "perf_result" is present), or None on any read/parse error.
        """
        try:
            with open(json_file, "r") as f:
                content = f.read().strip()

            # Parse the JSON directly since executables produce clean JSON
            data = json.loads(content)

            # Return the complete JSON data as-is, just add some convenience fields
            result = data.copy()
            if "perf_result" in data:
                perf = data["perf_result"]
                # Add convenience fields for backward compatibility
                result["time_ms"] = perf.get("latency(ms)", 0)
                result["tflops"] = perf.get("tflops(TFlops)", 0)
                result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
            return result
        except json.JSONDecodeError as e:
            if self.verbose:
                print(f"Failed to parse JSON from {json_file}: {e}")
            return None
        except Exception as e:
            if self.verbose:
                print(f"Error reading JSON file {json_file}: {e}")
            return None

    def benchmark_problem_size(
        self,
        kernels: List[Path],
        m: int,
        n: int,
        k: int,
        split_k: int = 1,
        verify: int = 0,
        warmup: int = 50,
        repeat: int = 100,
        flush_cache: bool = True,
        rotating_count: int = 1000,
    ) -> List[Dict]:
        """Benchmark all kernels for a specific problem size

        Kernels whose run fails are silently left out of the returned list.
        """
        results = []

        params = {
            "m": m,
            "n": n,
            "k": k,
            "split_k": split_k,
            "verify": verify,
            "warmup": warmup,
            "repeat": repeat,
            "flush_cache": str(flush_cache).lower(),
            "rotating_count": rotating_count,
        }

        print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}")

        for kernel_path in kernels:
            kernel_info = self.extract_kernel_info(kernel_path)
            result = self.run_kernel(kernel_path, params)

            if result:
                # Create new structured result format
                structured_result = {
                    "name": kernel_info["name"],  # Add name field for compatibility
                    "config_id": kernel_info["config_id"],
                    "problem": result.get("problem", {}),
                    "perf_result": result.get("perf_result", {}),
                    "config": {
                        "data_type": kernel_info["data_type"],
                        "layout": kernel_info["layout"],
                        "pipeline": kernel_info["pipeline"],
                        "scheduler": kernel_info["scheduler"],
                        "epilogue": kernel_info["epilogue"],
                        "tile_sizes": kernel_info.get("tile_sizes", {}),
                        "warp_config": kernel_info.get("warp_config", {}),
                        "warp_tile": kernel_info.get("warp_tile", {}),
                        "optimization_flags": kernel_info.get("optimization_flags", {}),
                    },
                    "executable": kernel_info["executable"],
                    # Keep backward compatibility fields
                    "time_ms": result.get("time_ms", 0),
                    "tflops": result.get("tflops", 0),
                    "bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
                }

                results.append(structured_result)

                if self.verbose:
                    print(
                        f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
                    )

        return results

    def find_best_kernel(
        self, results: List[Dict], metric: str = "tflops"
    ) -> Optional[Dict]:
        """Find the best performing kernel based on metric

        Raises:
            ValueError: for a metric other than "tflops", "time_ms" or "bandwidth_gb_s".
        """
        if not results:
            return None

        if metric == "tflops":
            return max(results, key=lambda x: x.get("tflops", 0))
        elif metric == "time_ms":
            return min(results, key=lambda x: x.get("time_ms", float("inf")))
        elif metric == "bandwidth_gb_s":
            return max(results, key=lambda x: x.get("bandwidth_gb_s", 0))
        else:
            raise ValueError(f"Unknown metric: {metric}")

    def benchmark_sweep(
        self,
        problem_sizes: List[Tuple[int, int, int]],
        split_k_values: Optional[List[int]] = None,
        verify: bool = False,
        warmup: int = 50,
        repeat: int = 100,
        flush_cache: bool = True,
        rotating_count: int = 1000,
    ) -> Dict:
        """Run comprehensive benchmark sweep

        Args:
            problem_sizes: (M, N, K) tuples to benchmark.
            split_k_values: split-K values to test; None means ``[1]``.
            verify: when True the executables are run with ``verify=2``.

        Returns:
            Mapping "m{M}_n{N}_k{K}_splitk{S}" -> best (highest TFLOPS) result.
            All individual results are stored in ``self.results``.
        """
        # None instead of a list literal avoids sharing one mutable default between calls.
        if split_k_values is None:
            split_k_values = [1]

        kernels = self.discover_kernels()
        if not kernels:
            print("No kernels found!")
            return {}

        all_results = []
        best_kernels = {}

        for m, n, k in problem_sizes:
            for split_k in split_k_values:
                results = self.benchmark_problem_size(
                    kernels,
                    m,
                    n,
                    k,
                    split_k,
                    verify=2 if verify else 0,
                    warmup=warmup,
                    repeat=repeat,
                    flush_cache=flush_cache,
                    rotating_count=rotating_count,
                )

                all_results.extend(results)

                # Find best kernel for this configuration
                best = self.find_best_kernel(results)
                if best:
                    key = f"m{m}_n{n}_k{k}_splitk{split_k}"
                    best_kernels[key] = best
                    print(
                        f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
                    )

        self.results = all_results
        return best_kernels

    def export_csv(self, filename: str):
        """Export all results to CSV"""
        if not self.results:
            print("No results to export")
            return

        # Get all unique keys from results
        all_keys = set()
        for result in self.results:
            all_keys.update(result.keys())

        # Sort keys for consistent output
        fieldnames = sorted(all_keys)

        with open(filename, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.results)

        print(f"Results exported to {filename}")

    def export_best_kernels(self, best_kernels: Dict, filename: str):
        """Export best kernel selections to file"""
        with open(filename, "w") as f:
            f.write("# Best kernel selections\n")
            f.write(
                "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n"
            )
            for key, kernel in sorted(best_kernels.items()):
                f.write(
                    f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
                )
        print(f"Best kernels exported to {filename}")

    def export_json(self, filename: str, best_kernels: Optional[Dict] = None):
        """Export all results and best kernels to JSON with comprehensive metadata

        A result counts as successful when its "tflops" value is greater than 0.
        Medians are true medians (mean of the two middle values for an even count).
        """
        import statistics
        from datetime import datetime

        # Calculate comprehensive summary statistics for all metrics
        successful_results = [r for r in self.results if r.get("tflops", 0) > 0]
        tflops_values = [r.get("tflops", 0) for r in successful_results]
        bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
        latency_values = [
            r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
        ]

        # Performance breakdown by kernel type
        pipeline_stats = {}
        scheduler_stats = {}
        data_type_stats = {}

        for result in successful_results:
            # Get config info from the new structure
            config = result.get("config", {})

            # Pipeline statistics
            pipeline = config.get("pipeline", "unknown")
            if pipeline not in pipeline_stats:
                pipeline_stats[pipeline] = {
                    "count": 0,
                    "avg_tflops": 0,
                    "best_tflops": 0,
                }
            pipeline_stats[pipeline]["count"] += 1
            pipeline_stats[pipeline]["best_tflops"] = max(
                pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0)
            )

            # Scheduler statistics
            scheduler = config.get("scheduler", "unknown")
            if scheduler not in scheduler_stats:
                scheduler_stats[scheduler] = {
                    "count": 0,
                    "avg_tflops": 0,
                    "best_tflops": 0,
                }
            scheduler_stats[scheduler]["count"] += 1
            scheduler_stats[scheduler]["best_tflops"] = max(
                scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0)
            )

            # Data type statistics
            data_type = config.get("data_type", "unknown")
            if data_type not in data_type_stats:
                data_type_stats[data_type] = {
                    "count": 0,
                    "avg_tflops": 0,
                    "best_tflops": 0,
                }
            data_type_stats[data_type]["count"] += 1
            data_type_stats[data_type]["best_tflops"] = max(
                data_type_stats[data_type]["best_tflops"], result.get("tflops", 0)
            )

        # Calculate averages for breakdown stats
        for stats_dict, field_name in [
            (pipeline_stats, "pipeline"),
            (scheduler_stats, "scheduler"),
            (data_type_stats, "data_type"),
        ]:
            for key in stats_dict:
                relevant_results = [
                    r
                    for r in successful_results
                    if r.get("config", {}).get(field_name, "unknown") == key
                ]
                if relevant_results:
                    stats_dict[key]["avg_tflops"] = sum(
                        r.get("tflops", 0) for r in relevant_results
                    ) / len(relevant_results)

        output_data = {
            "benchmark_metadata": {
                "timestamp": datetime.now().isoformat(),
                "total_kernels_tested": len(self.results),
                "unique_kernels": len(
                    set(r.get("name", "unknown") for r in self.results)
                ),
                "successful_runs": len(successful_results),
                "failed_runs": len(self.results) - len(successful_results),
            },
            "performance_summary": {
                "tflops_stats": {
                    "best": max(tflops_values, default=0),
                    "average": sum(tflops_values) / len(tflops_values)
                    if tflops_values
                    else 0,
                    "min": min(tflops_values, default=0),
                    "median": statistics.median(tflops_values)
                    if tflops_values
                    else 0,
                },
                "bandwidth_stats": {
                    "best_gb_s": max(bandwidth_values, default=0),
                    "average_gb_s": sum(bandwidth_values) / len(bandwidth_values)
                    if bandwidth_values
                    else 0,
                    "min_gb_s": min(bandwidth_values, default=0),
                    "median_gb_s": statistics.median(bandwidth_values)
                    if bandwidth_values
                    else 0,
                },
                "latency_stats": {
                    "best_ms": min(latency_values, default=0),
                    "average_ms": sum(latency_values) / len(latency_values)
                    if latency_values
                    else 0,
                    "max_ms": max(latency_values, default=0),
                    "median_ms": statistics.median(latency_values)
                    if latency_values
                    else 0,
                },
                "kernel_type_breakdown": {
                    "by_pipeline": pipeline_stats,
                    "by_scheduler": scheduler_stats,
                    "by_data_type": data_type_stats,
                },
                "total_problem_configurations": len(best_kernels)
                if best_kernels
                else 0,
            },
            "kernel_results": self.results,
            "best_kernels_by_problem": best_kernels or {},
        }

        with open(filename, "w") as f:
            json.dump(output_data, f, indent=2)

        print(f"JSON results exported to {filename}")
        print(f" - Total kernels: {len(self.results)}")
        print(f" - Successful runs: {len(successful_results)}")
        print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
        print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
        print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
def main():
    """Command-line entry point: parse arguments, run the sweep, export the results.

    Returns:
        0 on success, 1 when a --problem-sizes entry is not three comma-separated integers.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Preshuffle Kernel Benchmarking Tool"
    )
    parser.add_argument(
        "build_dir", help="Build directory containing kernel executables"
    )
    parser.add_argument(
        "--problem-sizes",
        nargs="+",
        default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
        help="Problem sizes as M,N,K tuples",
    )
    parser.add_argument(
        "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test"
    )
    parser.add_argument("--verify", action="store_true", help="Enable verification")
    parser.add_argument(
        "--csv",
        default="gemm_preshuffle_benchmark_results.csv",
        help="CSV output filename",
    )
    parser.add_argument(
        "--best", default="best_kernels.txt", help="Best kernels output filename"
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--warmup",
        type=int,
        default=50,
        help="Number of warmup iterations (default: 50)",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=100,
        help="Number of benchmark iterations (default: 100)",
    )
    parser.add_argument(
        "--flush-cache",
        action="store_true",
        default=True,
        help="Enable cache flushing (default: True)",
    )
    # --flush-cache alone can only ever yield True (store_true with default=True);
    # this companion flag writes False into the same destination so that cache
    # flushing can actually be switched off.
    parser.add_argument(
        "--no-flush-cache",
        dest="flush_cache",
        action="store_false",
        help="Disable cache flushing",
    )
    parser.add_argument(
        "--rotating-count",
        type=int,
        default=1000,
        help="Number of iterations to rotate cache (default: 1000)",
    )
    parser.add_argument("--json", help="JSON output filename (optional)")

    args = parser.parse_args()

    # Parse problem sizes
    problem_sizes = []
    for size_str in args.problem_sizes:
        try:
            # ValueError covers both non-integer fields and a wrong number of fields.
            m, n, k = map(int, size_str.split(","))
            problem_sizes.append((m, n, k))
        except ValueError:
            print(f"Invalid problem size: {size_str}")
            return 1

    # Create benchmark instance
    benchmark = GemmPreshuffleBenchmark(args.build_dir, verbose=args.verbose)

    # Run benchmark sweep
    print("Starting GEMM Preshuffle kernel benchmark sweep...")
    start_time = time.time()

    best_kernels = benchmark.benchmark_sweep(
        problem_sizes=problem_sizes,
        split_k_values=args.split_k,
        verify=args.verify,
        warmup=args.warmup,
        repeat=args.repeat,
        flush_cache=args.flush_cache,
        rotating_count=args.rotating_count,
    )

    elapsed_time = time.time() - start_time
    print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")

    # Export results
    benchmark.export_csv(args.csv)
    benchmark.export_best_kernels(best_kernels, args.best)

    # Export JSON if requested
    if args.json:
        benchmark.export_json(args.json, best_kernels)

    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,171 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_preshuffle_profiler.hpp"
#include "gemm_preshuffle_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
// Builds the command-line parser for the single-kernel benchmark and parses argv.
//
// Every option is registered with a default value, so the executable can be
// started without any arguments. The help text of each option states the default
// that is actually registered here.
//
// Returns a tuple (parse_ok, parser): parse_ok is the value returned by
// ArgParser::parse, parser holds the parsed options.
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
        .insert("n", "4096", "The value for n dimension. Default is 4096.")
        .insert("k", "2048", "The value for k dimension. Default is 2048.")
        .insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
        .insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
        .insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
        .insert("split_k", "1", "The split value for k dimension. Default is 1.")
        // The registered default is "2"; the help text previously claimed 0.
        .insert("verify",
                "2",
                "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
                "for validation on GPU. Default is 2, validation on GPU.")
        .insert("log",
                "false",
                "Whether output kernel instance information or not. Possible values are true or "
                "false. Default is false")
        .insert(
            "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
        .insert(
            "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
        .insert("timer",
                "true",
                "Whether if the timer is gpu timer or not. Possible values are false or true. "
                "Default is true.")
        .insert("init",
                "0",
                "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
                "for constant(1). Default is 0, random.")
        // The registered default is "true"; the help text previously claimed false.
        .insert("flush_cache",
                "true",
                "To flush cache, possible values are true or false. "
                "Default is true.")
        // The registered default is "1000"; the help text previously claimed 5.
        .insert(
            "rotating_count", "1000", "number of iterations to rotate the cache. default is 1000.")
        .insert("metric",
                "0",
                "Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
                "tflops, or 2 for bandwidth. Default is 0, latency.")
        .insert("csv_filename",
                "",
                "The filename of benchmark result. Default is empty (no CSV output).")
        .insert("structured_sparsity",
                "false",
                "Whether use sparsity kernel or not. Possible values are true or false. Default is "
                "false")
        .insert("json_output",
                "false",
                "Whether to output results in JSON format only. Possible values are true or false. "
                "Default is "
                "false");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Benchmarks the one kernel that was compiled into this executable.
//
// The problem description and the benchmark settings are read from the parsed
// command line; the data type and layout names come from the type aliases
// (ADataType, ALayout, ...) provided by the force-included kernel header.
// Any std::exception raised while benchmarking is reported on stderr and
// swallowed, so this function itself does not propagate benchmark failures.
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
    // Printable names of the element types selected by the generated header.
    std::string a_type_name(DataTypeTraits<ADataType>::name);
    std::string b_type_name(DataTypeTraits<BDataType>::name);
    std::string acc_type_name(DataTypeTraits<AccDataType>::name);
    std::string c_type_name(DataTypeTraits<CDataType>::name);

    // Printable names of the tensor layouts.
    std::string a_layout_name(ALayout::name);
    std::string b_layout_name(BLayout::name);
    std::string c_layout_name(CLayout::name);

    // Problem description: split-k factor, sizes, strides, then the names above.
    GemmProblem gemm_problem{arg_parser.get_int("split_k"),
                             arg_parser.get_int("m"),
                             arg_parser.get_int("n"),
                             arg_parser.get_int("k"),
                             arg_parser.get_int("stride_a"),
                             arg_parser.get_int("stride_b"),
                             arg_parser.get_int("stride_c"),
                             a_type_name,
                             b_type_name,
                             acc_type_name,
                             c_type_name,
                             a_layout_name,
                             b_layout_name,
                             c_layout_name,
                             arg_parser.get_bool("structured_sparsity")};

    // Benchmark run settings, in the field order expected by Setting.
    Setting setting{arg_parser.get_int("warmup"),
                    arg_parser.get_int("repeat"),
                    arg_parser.get_bool("timer"),
                    arg_parser.get_int("verify"),
                    arg_parser.get_int("init"),
                    arg_parser.get_bool("log"),
                    arg_parser.get_str("csv_filename"),
                    arg_parser.get_bool("flush_cache"),
                    arg_parser.get_int("rotating_count"),
                    arg_parser.get_bool("json_output")};

    // Process-wide profiler object.
    auto& profiler = GemmProfiler::instance(setting);

    try
    {
        // Compile-time tiling parameters of the selected kernel, packed as (M, N, K).
        std::tuple<int, int, int> block_tile =
            std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK);
        std::tuple<int, int, int> warps_per_block =
            std::make_tuple(SelectedKernel::WarpPerBlock_M,
                            SelectedKernel::WarpPerBlock_N,
                            SelectedKernel::WarpPerBlock_K);
        std::tuple<int, int, int> warp_tile = std::make_tuple(
            SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK);
        bool permute_n = SelectedKernel::PermuteN;

        KernelConfig config{block_tile, warps_per_block, warp_tile, permute_n};

        // Thin adapter that forwards to the generated kernel's launch entry point.
        auto launch_selected_kernel = [](const ck_tile::GemmHostArgs& host_args,
                                         const ck_tile::stream_config& stream_cfg) {
            return SelectedKernel::launch(host_args, stream_cfg);
        };

        // Run the measurement, then report the best recorded instance for the chosen metric.
        profiler.benchmark(gemm_problem, launch_selected_kernel, config);
        profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
    }
    catch(const std::exception& e)
    {
        std::cerr << "Benchmark failed: " << e.what() << std::endl;
    }
}
// Program entry point: parse the command line, then benchmark the compiled-in kernel.
// Returns EXIT_FAILURE when argument parsing fails or an exception escapes,
// and 0 otherwise.
int main(int argc, char* argv[])
{
    try
    {
        auto parsed = create_args(argc, argv);

        const bool parse_ok = std::get<0>(parsed);
        if(!parse_ok)
        {
            return EXIT_FAILURE;
        }

        benchmark_single(std::get<1>(parsed));
        return 0;
    }
    catch(const std::exception& error)
    {
        std::cerr << "Error: " << error.what() << "\n";
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,181 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
// Compile-time mapping from an element type to a short printable name.
// Only the primary template's declaration appears here; the explicit
// specializations below each provide a `name` string for one type.
template <typename T>
struct DataTypeTraits;
// 32-bit float.
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
// 64-bit float.
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
// ck_tile half-precision type.
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// ck_tile bf16 type.
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
// ck_tile fp8 type.
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
// ck_tile bf8 type.
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
// ck_tile 8-bit integer type.
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
// ck_tile 32-bit integer type.
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
// ck_tile packed int4 type; note the name keeps the "_t" suffix, unlike the others.
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher
// Describes the traits of one kernel as seen by the dispatcher.
// A default-constructed object selects the "preshufflev2" pipeline, the
// "default" scheduler and epilogue, and has every flag cleared.
struct KernelTraits
{
    std::string pipeline  = "preshufflev2"; // preshufflev2
    std::string scheduler = "default";      // intrawave, interwave, default
    std::string epilogue  = "default";      // cshuffle, default
    bool pad_m            = false;
    bool pad_n            = false;
    bool pad_k            = false;
    bool persistent       = false;

    // All defaults come from the member initializers above.
    KernelTraits() {}
};
// Helper to extract traits from kernel name
// Derives pipeline, scheduler and epilogue from substrings of a kernel name.
//
// - pipeline:  set to "preshufflev2" when that substring occurs (this is also
//              the default-constructed value).
// - scheduler: "interwave" is checked first, then "intrawave", otherwise "default".
// - epilogue:  "default" only when the name contains "default" but never
//              "default_"; every other name yields "cshuffle".
// The padding and persistent flags are not derived from the name and keep
// their default-constructed value (false).
//
// NOTE(review): a name in which "default" is followed by an underscore is
// always classified as "cshuffle"; confirm this matches the generator's
// naming scheme.
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
{
    const auto name_has = [&kernel_name](const char* token) {
        return kernel_name.find(token) != std::string::npos;
    };

    KernelTraits traits;

    if(name_has("preshufflev2"))
    {
        traits.pipeline = "preshufflev2";
    }

    traits.scheduler = name_has("interwave")   ? "interwave"
                       : name_has("intrawave") ? "intrawave"
                                               : "default";

    const bool is_default_epilogue = name_has("default") && !name_has("default_");
    traits.epilogue                = is_default_epilogue ? "default" : "cshuffle";

    return traits;
}
// Produces a re-ordered host copy of a 2-D B tensor.
//
// The elements of `t` are copied, in iteration order, into a 5-D tensor of shape
//   {n / N_Warp_Tile, N_Warp_Tile, k / K_Warp_Tile, divisor, K_Warp_Tile / divisor}
// where k is read from dimension 0 of `t`, n from dimension 1, and divisor is
// 2 when N_Warp_Tile == 32 and 4 otherwise. The result is that tensor permuted
// with the axis order {0, 2, 3, 1, 4}.
// The divisions are integer divisions; no divisibility check is performed here.
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t,
               ck_tile::index_t N_Warp_Tile,
               ck_tile::index_t K_Warp_Tile)
{
    assert(t.get_lengths().size() == 2);

    const int k_len   = t.get_lengths()[0];
    const int n_len   = t.get_lengths()[1];
    const int k_split = (N_Warp_Tile == 32) ? 2 : 4;

    ck_tile::HostTensor<T> tiled(
        {n_len / N_Warp_Tile, N_Warp_Tile, k_len / K_Warp_Tile, k_split, K_Warp_Tile / k_split});
    std::copy(t.begin(), t.end(), tiled.begin());

    return ck_tile::reference_permute(tiled, {0, 2, 3, 1, 4});
}
// Variant of the B re-ordering that additionally splits the N dimension by
// block tile, warp and repeat.
//
// The elements of `t` are copied, in iteration order, into a 7-D tensor of shape
//   {n / N_Tile, N_Warp, N_Warp_Tile, NRepeat,
//    k / K_Warp_Tile, divisor, K_Warp_Tile / divisor}
// with NRepeat = N_Tile / N_Warp_Tile / N_Warp, k read from dimension 0 of `t`,
// n from dimension 1, and divisor equal to 2 when N_Warp_Tile == 32 and 4
// otherwise. The result is that tensor permuted with axis order
// {0, 3, 1, 4, 5, 2, 6}.
// The divisions are integer divisions; no divisibility check is performed here.
template <typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
                        ck_tile::index_t N_Warp_Tile,
                        ck_tile::index_t K_Warp_Tile,
                        ck_tile::index_t N_Tile,
                        ck_tile::index_t N_Warp)
{
    assert(t.get_lengths().size() == 2);

    const int k_len    = t.get_lengths()[0];
    const int n_len    = t.get_lengths()[1];
    const int k_split  = (N_Warp_Tile == 32) ? 2 : 4;
    const int n_repeat = N_Tile / N_Warp_Tile / N_Warp;

    ck_tile::HostTensor<T> tiled({n_len / N_Tile,
                                  N_Warp,
                                  N_Warp_Tile,
                                  n_repeat,
                                  k_len / K_Warp_Tile,
                                  k_split,
                                  K_Warp_Tile / k_split});
    std::copy(t.begin(), t.end(), tiled.begin());

    return ck_tile::reference_permute(tiled, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -0,0 +1,300 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Load ``GemmKernelBuilder`` from ``gemm_instance_builder.py``.

    The module file is looked up in the parent of the directory that contains
    this script and is executed under the module name
    ``gemm_instance_builder``.

    Returns:
        The ``GemmKernelBuilder`` attribute of the loaded module.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    module_path = os.path.join(os.path.dirname(script_dir), "gemm_instance_builder.py")

    module_spec = importlib.util.spec_from_file_location(
        "gemm_instance_builder", module_path
    )
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)

    return loaded_module.GemmKernelBuilder
GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
    """Kernel builder for the GEMM preshuffle operation.

    Construction is delegated unchanged to ``GemmKernelBuilder``; this class
    adds the parallel generation of one header file per kernel.
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
    ):
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )

    def _generate_all_individual(self, num_workers=None):
        """Generate one kernel file per (tile config, trait combination) pair.

        The work is distributed over a process pool. Failed or empty results
        are skipped; the successful ones are sorted by kernel name and handed
        to ``_generate_cmake_individual_targets``.

        Args:
            num_workers: Size of the process pool. When ``None``, the CPU
                count capped at 8 is used.
        """
        if num_workers is None:
            # Cap the pool size to avoid memory issues.
            num_workers = min(multiprocessing.cpu_count(), 8)

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # One self-contained tuple per kernel; the worker rebuilds its own
        # builder from these values.
        work_items = [
            (
                tile_config,
                trait_combo,
                self.kernel_name_prefix,
                self.working_path,
                self.gpu_target,
                self.datatype,
                self.layout,
                self.config_json,
            )
            for tile_config in tile_configs
            for trait_combo in trait_combos
        ]
        total = len(work_items)

        print(
            f"Generating {total} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {total}")

        # Print the first work item as a debugging aid.
        if work_items:
            first_tile, first_traits = work_items[0][:2]
            print(" First work item example:")
            print(f" Tile config: {first_tile}")
            print(f" Trait combo: {first_traits[:3]}")  # first 3 traits only

        kernel_list = []
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            print(f" Submitting {total} tasks to executor...")
            pending = {}
            for item in work_items:
                pending[executor.submit(_generate_single_kernel_individual, item)] = item
            print(" All tasks submitted, waiting for completion...")

            # Report progress every 100 completions and at the very end.
            for completed, future in enumerate(
                concurrent.futures.as_completed(pending), start=1
            ):
                if completed % 100 == 0 or completed == total:
                    print(f" Progress: {completed}/{total} kernels generated")
                try:
                    outcome = future.result()
                except Exception as exc:
                    print(f"Kernel generation failed for {pending[future]}: {exc}")
                else:
                    if outcome:
                        kernel_list.append(outcome)

        # Sort by kernel name for consistent ordering.
        kernel_list.sort(key=lambda entry: entry[0])

        # Emit the CMake include file describing the individual targets.
        self._generate_cmake_individual_targets(kernel_list)

        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker function that generates and writes one individual kernel header.

    Args:
        work_item: Tuple of ``(tile_config, trait_combo, kernel_name_prefix,
            working_path, gpu_target, datatype, layout, config_json)``.
            ``working_path`` may be a string or a path object.

    Returns:
        ``(kernel_name, trait_combo, tile_config)`` on success, or ``None``
        when generation or writing raised an exception (the error is printed).
    """
    # Function-scope import keeps this block self-contained.
    from pathlib import Path

    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json,
    ) = work_item

    # Create a temporary builder instance for this worker
    builder = GemmPreshuffleKernelBuilder(
        kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
    )

    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Drop the "<prefix>_" lead from the kernel name for the file name.
        # The same prefix is used for both the test and the slice, so the two
        # cannot disagree when a different prefix is supplied.
        name_prefix = f"{kernel_name_prefix}_"
        if kernel_name.startswith(name_prefix):
            simplified_name = kernel_name[len(name_prefix) :]
        else:
            simplified_name = kernel_name

        # Path() accepts both str and path objects, so a plain string
        # working_path no longer fails on the "/" operator.
        header_file = Path(working_path) / f"gemm_preshuffle_single_{simplified_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)

        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """Command-line entry point of the GEMM preshuffle kernel instance builder.

    Exactly one mode has to be selected; they are checked in this order:

    * ``--list_kernels``: list kernel configurations without generating files.
    * ``--gen_single``: generate and write the header for one kernel described
      by ``--tile_config`` and ``--trait_combo`` (``--kernel_name`` must also
      be given).
    * ``--gen_all_individual``: generate every kernel header in parallel.

    Invalid or missing arguments are reported through ``parser.error``, which
    prints the usage text and exits.
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", required=True, help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()

    # The argparse `choices` above already restrict these values. The explicit
    # checks are kept as a second line of defence, but use parser.error so they
    # are not stripped when Python runs with -O (as `assert` would be).
    if args.datatype not in ["fp16", "bf16", "fp8", "bf8"]:
        parser.error(
            f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
        )
    layout_parts = args.layout.lower()
    if len(layout_parts) != 3:
        parser.error(
            f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
        )
    if not (layout_parts[0] in ["r"] and layout_parts[1] in ["c"]):
        parser.error(
            f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
        )
    if layout_parts[2] != "r":
        parser.error(
            f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
        )

    # Create builder
    kernel_name_prefix = "gemm_preshuffle"
    builder = GemmPreshuffleKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.config_json,
    )

    if args.list_kernels:
        # Fast listing mode - just write kernel list without generating files
        builder._list_kernels()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Parse tile config: three "_"-separated groups of three "x"-separated
        # integers (block tile, warps per block, warp tile).
        tile_groups = [group.split("x") for group in args.tile_config.split("_")]
        if len(tile_groups) < 3 or any(len(group) < 3 for group in tile_groups[:3]):
            parser.error(
                f"Invalid --tile_config: {args.tile_config} (expected MxNxK_MxNxK_MxNxK)"
            )
        try:
            tile_dims, warp_dims, warp_tile_dims = (
                [int(value) for value in group[:3]] for group in tile_groups[:3]
            )
        except ValueError:
            parser.error(
                f"Invalid --tile_config: {args.tile_config} (all dimensions must be integers)"
            )
        tile_config = {
            "tile_m": tile_dims[0],
            "tile_n": tile_dims[1],
            "tile_k": tile_dims[2],
            "warp_m": warp_dims[0],
            "warp_n": warp_dims[1],
            "warp_k": warp_dims[2],
            "warp_tile_m": warp_tile_dims[0],
            "warp_tile_n": warp_tile_dims[1],
            "warp_tile_k": warp_tile_dims[2],
        }

        # Parse trait combo: pipeline, epilogue, scheduler, then four booleans
        # spelled "True"/"False".
        trait_parts = args.trait_combo.split("_")
        if len(trait_parts) < 7:
            parser.error(
                f"Invalid --trait_combo: {args.trait_combo} (expected pipeline_epilogue_scheduler_padM_padN_padK_persistent)"
            )
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3] == "True",  # pad_m
            trait_parts[4] == "True",  # pad_n
            trait_parts[5] == "True",  # pad_k
            trait_parts[6] == "True",  # persistent
        )

        # Generate the kernel AND write its header. Previously the generated
        # code was discarded, so this mode produced no file. The worker used by
        # the parallel path is reused so both modes name and write files the
        # same way.
        # NOTE(review): --kernel_name is required but is not compared with the
        # generated kernel name - confirm whether it should be.
        generated = _generate_single_kernel_individual(
            (
                tile_config,
                trait_combo,
                kernel_name_prefix,
                builder.working_path,
                builder.gpu_target,
                builder.datatype,
                builder.layout,
                builder.config_json,
            )
        )
        if generated is None:
            parser.exit(
                1,
                f"Failed to generate kernel for tile config {args.tile_config} and trait combo {args.trait_combo}\n",
            )
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder._generate_all_individual(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,289 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gemm_preshuffle_benchmark.hpp"
class GemmProfiler
{
public:
// Returns the single GemmProfiler object of the process.
// The object is a function-local static, so it is constructed from `setting`
// on the first call only; the argument of later calls is not used.
static GemmProfiler& instance(Setting setting)
{
    static GemmProfiler profiler{setting};
    return profiler;
}
// Overload for single kernel benchmarking
// Benchmarks a single kernel launcher.
// `kernel_func` returns the measured time; it is wrapped into a callable that
// also reports KERNEL_NAME, and the one-element list is forwarded to the
// list-based overload.
void benchmark(GemmProblem& gemm_problem,
               std::function<float(const ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
                   kernel_func,
               KernelConfig& config)
{
    using NamedKernel = std::function<std::tuple<std::string, float>(
        ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>;

    std::vector<NamedKernel> named_kernels;
    named_kernels.emplace_back(
        [kernel_func](ck_tile::GemmHostArgs& host_args, const ck_tile::stream_config& stream_cfg) {
            const float elapsed = kernel_func(host_args, stream_cfg);
            return std::make_tuple(std::string(KERNEL_NAME), elapsed);
        });

    benchmark(gemm_problem, named_kernels, config);
}
// Runs every callable in `callables` on the given problem and records the results.
//
// Steps, in order:
//   1. replace the three strides in `gemm_problem` with the values returned by
//      ck_tile::get_default_stride (the problem object is modified in place);
//   2. create host tensors for A, B and C and fill A and B according to
//      setting_.init_method_;
//   3. allocate device buffers and, when setting_.verify_ is non-zero, compute
//      the reference result into c_m_n_ref;
//   4. upload A, then for each callable: shuffle B on the host, upload it, run
//      the callable and pass its (name, time) result to process_result.
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables,
KernelConfig& config)
{
// Layout tag objects, used only to select the row-major / column-major paths.
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
// Resolve the strides; presumably a stride of 0 selects the packed default - TODO confirm.
gemm_problem.stride_a_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
gemm_problem.stride_b_ = ck_tile::get_default_stride(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
gemm_problem.stride_c_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
// Host tensors: A is m x k, B is k x n, C (device result) is m x n.
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
// Input initialization: 0 = uniform in [-0.5, 0.5], 1 = monotonic sequence,
// 2 = uniform with both bounds equal to 1, anything else = all zeros.
if(setting_.init_method_ == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else if(setting_.init_method_ == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
}
else if(setting_.init_method_ == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
// Device buffers sized from the host tensors.
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
// Reference Verification
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
c_m_n_ref.SetZero();
// The reference is computed from the unshuffled B tensor, before any kernel runs.
if(setting_.verify_)
{
gemm_host_reference(setting_.verify_,
a_m_k,
b_k_n,
c_m_n_ref,
a_m_k_dev_buf,
b_k_n_dev_buf,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
gemm_problem.stride_c_);
}
// Kernel Execution
a_m_k_dev_buf.ToDevice(a_m_k.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
for(const auto& callable : callables)
{
// Tiling parameters that drive the host-side B shuffle (N and K components).
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
// Pick the shuffle variant from config.permuteN; both return a host tensor.
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if(config.permuteN)
{
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
}
else
{
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
}
}();
// The device B buffer receives the shuffled data, overwriting earlier contents.
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
ck_tile::GemmHostArgs gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
gemm_problem.stride_c_,
};
// stream_config is built positionally; the first two arguments are presumably
// the stream handle and a "time the kernel" flag - TODO confirm against ck_tile.
auto kernel_run_result = callable(gemm_args,
ck_tile::stream_config{nullptr,
true,
setting_.log_,
setting_.n_warmup_,
setting_.n_repeat_,
setting_.is_gpu_timer_,
setting_.flush_cache_,
setting_.rotating_count_});
// Computes metrics, verifies if enabled, and zeroes the C buffers for the next callable.
process_result(
gemm_problem, c_m_n_dev_buf, c_m_n_ref, c_m_n_dev_result, kernel_run_result);
}
}
// Turns one kernel run into a KernelInstance record.
//
// Computes latency, tflops and bandwidth from the measured time, optionally
// prints the record, copies the device result back to the host, and keeps the
// record only when verification is disabled or the comparison against the
// reference succeeds. The device and host result buffers are zeroed at the end.
void process_result(const GemmProblem& gemm_problem,
                    ck_tile::DeviceMem& c_m_n_dev_buf,
                    ck_tile::HostTensor<CDataType>& c_m_n_ref,
                    ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                    const std::tuple<std::string, float>& kernel_run_result)
{
    std::string name = std::get<0>(kernel_run_result);
    float avg_time   = std::get<1>(kernel_run_result);

    // Metrics start at -1 and are overwritten below.
    KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};

    // Work and traffic of the problem: 2*m*n*k operations, A + B + C bytes.
    std::size_t flop     = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
    std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
                           sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
                           sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;

    auto& perf      = kernel_instance.perf_result_;
    perf.latency_   = avg_time;
    perf.tflops_    = static_cast<float>(flop) / 1.E9 / avg_time;
    perf.bandwidth_ = num_byte / 1.E6 / avg_time;

    // Per-kernel logging is suppressed in JSON-only mode.
    if(setting_.log_ > 0 && !setting_.json_output_)
    {
        std::cout << kernel_instance << std::endl;
    }

    // Fetch the device result; the comparison runs only when verification is enabled.
    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
    bool passed =
        !setting_.verify_ ||
        compare(name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_ref);

    if(!passed)
    {
        std::cout << "Verification failed, skip kernel: " << name << std::endl;
    }
    else
    {
        kernel_instances_.emplace_back(kernel_instance);
    }

    // Reset both result buffers.
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();
}
// Picks the best recorded kernel instance for `metric`, prints it, and
// optionally appends it as one row to "<csv_filename_>.csv".
//
// The instance is chosen with std::max_element using
// PerformanceResult::compare with its two result arguments swapped.
// NOTE(review): the swap presumably makes max_element return the
// best-performing entry - confirm against PerformanceResult::compare.
//
// Output: in JSON mode only the instance is printed; otherwise a banner with
// the metric name and the instance. When a CSV file name is configured, the
// column header is written only if the file is empty.
//
// Throws std::runtime_error("Empty instances") when nothing was recorded.
// Returns a copy of the selected instance.
KernelInstance select_best_instance(Metric metric)
{
    if(kernel_instances_.empty())
        throw std::runtime_error("Empty instances");

    auto kernel_instance = *std::max_element(kernel_instances_.begin(),
                                             kernel_instances_.end(),
                                             [metric](const auto& a, const auto& b) {
                                                 return PerformanceResult::compare(
                                                     b.perf_result_, a.perf_result_, metric);
                                             });

    if(setting_.json_output_)
    {
        // Output clean JSON only
        std::cout << kernel_instance << std::endl;
    }
    else
    {
        std::cout << "**********************************" << std::endl;
        std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
                  << "Current kernel performance is: " << kernel_instance << std::endl;
        std::cout << "**********************************" << std::endl;
    }

    if(!setting_.csv_filename_.empty())
    {
        std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
        if(!file.is_open())
        {
            std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
        }
        else
        {
            // Right after opening in append mode the put position can still be
            // reported as 0 even if the file already has content, because the
            // jump to end-of-file happens on the first write. Seek to the end
            // explicitly so the header is written only for an empty file.
            file.seekp(0, std::ios::end);
            if(file.tellp() == 0)
            {
                file << "rocm_version,device_name,"
                     << "split_k,m,n,k,stride_a,stride_b,stride_c,"
                     << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
                     << "structured_sparsity," << "name,"
                     << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
            }

            const auto& problem = kernel_instance.problem_;
            const auto& name    = kernel_instance.name_;
            const auto& perf    = kernel_instance.perf_result_;

            // One CSV row; the three performance numbers use fixed notation
            // with four decimals.
            file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
                 << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
                 << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
                 << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
                 << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
                 << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
                 << "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
                 << std::setprecision(4) << perf.latency_ << "," << std::fixed
                 << std::setprecision(4) << perf.tflops_ << "," << std::fixed
                 << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
                 << "\n";

            if(!file)
            {
                std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
            }
        }
    }

    return kernel_instance;
}
// The profiler cannot be copied or copy-assigned.
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
private:
// Construction and destruction are private; objects are obtained through instance().
~GemmProfiler() { kernel_instances_.clear(); }
GemmProfiler(Setting setting) : setting_(setting) {}
// Benchmark settings captured at construction.
Setting setting_;
// Instances appended by process_result; select_best_instance chooses among them.
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -0,0 +1,309 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# User-tunable cache entries. Both lists are semicolon-separated; the main build
# logic below iterates over every datatype/layout pair.
set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)")
set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)")
# File name only: it is resolved against the configs/ folder next to this file.
# An environment variable with the same name takes precedence over this entry
# (see build_individual_gemm_universal_targets).
set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
# Opt-in: when ON and ccache is found, it is used as the C++ compiler launcher.
option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF)
# Store the directory path for use in functions
set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# create_individual_gemm_universal_target(<datatype> <layout> <trait> <tile_config> <config_json>)
#
# Registers one EXCLUDE_FROM_ALL benchmark executable for a single kernel
# configuration, together with the build-time rule that generates its
# instance header.
#
#   datatype    - datatype label, e.g. fp16
#   layout      - layout label, e.g. rcr
#   trait       - underscore-separated trait combo; the first three fields are
#                 pipeline, epilogue and scheduler
#   tile_config - <tile_m>x<tile_n>x<tile_k>_<warp_m>x<warp_n>x<warp_k>_<warp_tile_m>x<warp_tile_n>x<warp_tile_k>
#   config_json - JSON config file forwarded to the instance builder
#
# Reads GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL from the calling scope and
# returns early (with a warning) when it is empty. The benchmark_gemm_universal_*
# collection targets must already exist, because add_dependencies() is called
# on them.
function(create_individual_gemm_universal_target datatype layout trait tile_config config_json)
    # Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable
    if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
        message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
        return()
    endif()

    # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
    # The parsed values are not used further down; the list(GET ...) calls
    # still make a malformed tile_config fail at configure time.
    # First split by underscore to get three groups
    string(REPLACE "_" ";" config_groups "${tile_config}")
    list(GET config_groups 0 tile_dims)       # e.g., 256x256x32
    list(GET config_groups 1 warp_dims)       # e.g., 4x1x1
    list(GET config_groups 2 warp_tile_dims)  # e.g., 16x16x16

    # Parse tile dimensions
    string(REPLACE "x" ";" tile_parts "${tile_dims}")
    list(GET tile_parts 0 tile_m)
    list(GET tile_parts 1 tile_n)
    list(GET tile_parts 2 tile_k)

    # Parse warp dimensions
    string(REPLACE "x" ";" warp_parts "${warp_dims}")
    list(GET warp_parts 0 warp_m)
    list(GET warp_parts 1 warp_n)
    list(GET warp_parts 2 warp_k)

    # Parse warp tile dimensions
    string(REPLACE "x" ";" warp_tile_parts "${warp_tile_dims}")
    list(GET warp_tile_parts 0 warp_tile_m)
    list(GET warp_tile_parts 1 warp_tile_n)
    list(GET warp_tile_parts 2 warp_tile_k)

    set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Generate the single instance header for this kernel
    set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

    # Add custom command to generate the header file at build time.
    # VERBATIM makes CMake escape every argument for the native build tool, so
    # the semicolon-separated GPU target list and any path containing spaces
    # reach the script as single, unmodified arguments.
    add_custom_command(
        OUTPUT "${instance_header}"
        COMMAND ${Python3_EXECUTABLE} "${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py"
                --working_path "${working_path}"
                --datatype "${datatype}"
                --layout "${layout}"
                --config_json "${config_json}"
                --gen_single
                --kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
                --gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}"
        DEPENDS "${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py" "${config_json}"
        COMMENT "Generating ${instance_header}"
        VERBATIM
    )

    # Create the executable
    add_executable(${target_name}
        EXCLUDE_FROM_ALL
        ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_benchmark_single.cpp
        ${instance_header}
    )

    # Set GPU architectures
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL})

    # Set compile definitions
    target_compile_definitions(${target_name} PRIVATE
        GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}"
    )

    # Include directories
    target_include_directories(${target_name} PRIVATE
        ${GEMM_UNIVERSAL_SOURCE_DIR}
        ${working_path}
    )

    # Compile options; the generated header is force-included into the
    # benchmark source through -include.
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )

    # Add to collection targets
    add_dependencies(benchmark_gemm_universal_all ${target_name})
    add_dependencies(benchmark_gemm_universal_${datatype} ${target_name})
    add_dependencies(benchmark_gemm_universal_${layout} ${target_name})
    add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name})

    # Add to trait-specific targets
    string(REPLACE "_" ";" trait_parts "${trait}")
    list(GET trait_parts 0 pipeline)
    list(GET trait_parts 1 epilogue)
    list(GET trait_parts 2 scheduler)
    add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name})
    add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name})
    add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name})
endfunction()
# build_individual_gemm_universal_targets(<datatype> <layout>)
#
# Runs the instance builder with --list_kernels at configure time for one
# datatype/layout pair, then creates one benchmark target per line of the
# resulting kernel list file.
#
# Config file selection, highest priority first:
#   1. environment variable GEMM_UNIVERSAL_CONFIG_FILE
#   2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE
#   3. configs/default_config.json
#
# Stops with FATAL_ERROR when the config file is missing, when the listing
# script fails, or when the kernel count/list files are not found afterwards.
function(build_individual_gemm_universal_targets datatype layout)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Check environment variable first
    if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
        message(VERBOSE " Using config from environment variable: ${config_filename}")
    elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
        # Use CMake variable if set
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}")
        message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}")
    else()
        # Use default config for all layouts
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
        message(VERBOSE " Using default config for layout ${layout}")
    endif()

    # Check if config file exists
    if(NOT EXISTS "${json_blob}")
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()

    # The set of targets is derived from this file at configure time, so an
    # edit to it must trigger a re-configure; otherwise the kernel list goes stale.
    set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${json_blob}")

    # Determine number of workers for parallel generation
    # (only reported in the VERBOSE message below)
    if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
        set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    else()
        # Use processor count but limit to avoid memory issues
        cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
        math(EXPR num_workers "${num_cores}")
        if(num_workers GREATER 8)
            set(num_workers 8)
        endif()
    endif()

    # Generate individual kernel files using parallel version
    message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
    message(VERBOSE " Working path: ${working_path}")
    message(VERBOSE " Config file: ${json_blob}")
    message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
    message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py")

    # Create working directory first
    file(MAKE_DIRECTORY "${working_path}")

    message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")

    # First, just list the kernels (fast operation).
    # --gpu_target is quoted on purpose: unquoted, a multi-entry list such as
    # "gfx90a;gfx942" expands into several command-line arguments and the
    # script's argument parser rejects the extra ones. Quoting passes the
    # whole list as one argument, matching the --gen_single invocation.
    message(VERBOSE " Listing kernel configurations...")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u "${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py"
                --working_path "${working_path}"
                --datatype "${datatype}"
                --layout "${layout}"
                --config_json "${json_blob}"
                --gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}"
                --list_kernels
        WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}"
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )
    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
    endif()

    # Read kernel count
    if(EXISTS "${working_path}/gemm_universal_kernel_count.txt")
        file(READ "${working_path}/gemm_universal_kernel_count.txt" kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(VERBOSE " Found ${kernel_count} kernel configurations")
    else()
        message(FATAL_ERROR "Kernel count file not found")
    endif()

    # Read kernel list and create targets
    if(EXISTS "${working_path}/gemm_universal_kernel_list.txt")
        file(STRINGS "${working_path}/gemm_universal_kernel_list.txt" kernel_lines)
        foreach(line IN LISTS kernel_lines)
            # Parse line: kernel_name|tile_config|trait_combo
            string(REPLACE "|" ";" parts "${line}")
            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)
            # Create individual target
            create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
        endforeach()
    else()
        message(FATAL_ERROR "Kernel list file not found")
    endif()
endfunction()
# Entry point of this file: only per-kernel ("individual") builds are supported.
message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===")
message(VERBOSE "GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}")
message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Keep only the architectures these benchmarks are built for.
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
foreach(gpu_arch IN LISTS SUPPORTED_GPU_TARGETS)
    if(NOT gpu_arch IN_LIST DESIRED_TARGETS)
        continue()
    endif()
    list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${gpu_arch})
    message(VERBOSE " Adding GPU target: ${gpu_arch}")
endforeach()

if(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
    message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")

    # Job pools for parallel compilation control
    set_property(GLOBAL PROPERTY JOB_POOLS
        compile_heavy=4     # Limit heavy compilations to prevent OOM
        compile_normal=16   # Allow more parallel normal compilations
    )

    # Compiler cache is opt-in; it is off by default due to permission issues
    # in CI environments.
    if(NOT ENABLE_CCACHE_GEMM_UNIVERSAL)
        message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)")
    else()
        find_program(CCACHE_PROGRAM ccache)
        if(NOT CCACHE_PROGRAM)
            message(WARNING "ccache requested but not found")
        else()
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(VERBOSE "Using ccache for faster compilation")
        endif()
    endif()

    # Collection targets: one for everything, one per layout, one per
    # datatype, and one per datatype/layout pair.
    add_custom_target(benchmark_gemm_universal_all)
    foreach(layout_name IN LISTS GEMM_UNIVERSAL_LAYOUT)
        add_custom_target(benchmark_gemm_universal_${layout_name})
    endforeach()
    foreach(dtype_name IN LISTS GEMM_UNIVERSAL_DATATYPE)
        add_custom_target(benchmark_gemm_universal_${dtype_name})
        foreach(layout_name IN LISTS GEMM_UNIVERSAL_LAYOUT)
            add_custom_target(benchmark_gemm_universal_${dtype_name}_${layout_name})
        endforeach()
    endforeach()

    # Trait-based collection targets, keyed by the trait components that the
    # per-kernel targets attach themselves to.
    set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4")
    set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle")
    set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave")
    foreach(pipeline_name IN LISTS GEMM_UNIVERSAL_PIPELINES)
        add_custom_target(benchmark_gemm_universal_${pipeline_name}_pipeline)
    endforeach()
    foreach(epilogue_name IN LISTS GEMM_UNIVERSAL_EPILOGUES)
        add_custom_target(benchmark_gemm_universal_${epilogue_name}_epilogue)
    endforeach()
    foreach(scheduler_name IN LISTS GEMM_UNIVERSAL_SCHEDULERS)
        add_custom_target(benchmark_gemm_universal_${scheduler_name}_scheduler)
    endforeach()

    # Per-kernel targets come last: they call add_dependencies() on the
    # collection targets created above.
    foreach(dtype_name IN LISTS GEMM_UNIVERSAL_DATATYPE)
        foreach(layout_name IN LISTS GEMM_UNIVERSAL_LAYOUT)
            build_individual_gemm_universal_targets(${dtype_name} ${layout_name})
        endforeach()
    endforeach()
else()
    # Nothing to build for the configured architectures.
    message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
endif()

View File

@@ -0,0 +1,104 @@
{
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
},
"k_block_per_cu": 1
}

View File

@@ -0,0 +1,87 @@
{
"tile_config": {
"tile_m": {
"values": [
128
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
64
]
},
"warp_m": {
"values": [
4
]
},
"warp_n": {
"values": [
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"mem"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
true
]
}
},
"k_block_per_cu": 2
}

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -11,12 +11,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "gemm_profiler.hpp"
#include "gemm_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
// Primary template: declared but not defined here, so using DataTypeTraits<T>::name
// for a type without a specialization below is a compile-time error.
template <typename T>
struct DataTypeTraits;
// Each specialization exposes `name`, a short string label for the type.
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
// NOTE(review): this label keeps the "_t" suffix, unlike the other labels above;
// confirm that consumers of the name expect "pk_int4_t".
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Kernel traits bundle used by the dispatcher. A default-constructed instance
// describes: compv3 pipeline, intrawave scheduler, cshuffle epilogue, no
// padding, non-persistent.
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m            = false;
    bool pad_n            = false;
    bool pad_k            = false;
    bool persistent       = false;

    // Kept user-provided (not "= default") so the type stays a non-aggregate;
    // the defaults themselves come from the member initializers above.
    KernelTraits() {}
};

View File

@@ -0,0 +1,295 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Load ``GemmKernelBuilder`` from ``gemm_instance_builder.py`` one directory up.

    The module is loaded directly from its file path (relative to this file's
    location) and executed; the ``GemmKernelBuilder`` attribute of the loaded
    module is returned.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    base_builder_path = os.path.join(
        os.path.dirname(this_dir), "gemm_instance_builder.py"
    )

    # Load the module dynamically
    module_spec = importlib.util.spec_from_file_location(
        "gemm_instance_builder", base_builder_path
    )
    base_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(base_module)

    return base_module.GemmKernelBuilder
GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmUniversalKernelBuilder(GemmKernelBuilder):
    """GEMM Universal builder on top of ``GemmKernelBuilder``.

    Adds ``_generate_all_individual``, which fans the per-kernel header
    generation out over a process pool.
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
    ):
        """Forward every argument unchanged to the base builder."""
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )

    def _generate_all_individual(self, num_workers=None):
        """Generate one header per (tile config, trait combo) pair in parallel.

        Args:
            num_workers: size of the process pool; when ``None`` the CPU count
                is used, capped at 8.

        A kernel whose worker raises or returns a falsy result is reported (or
        silently dropped, for the falsy case) and left out of the list passed
        to ``_generate_cmake_individual_targets``.
        """
        if num_workers is None:
            # Limit to avoid memory issues
            num_workers = min(multiprocessing.cpu_count(), 8)

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # Each work item carries everything a worker process needs to rebuild
        # its own builder instance.
        shared_args = (
            self.kernel_name_prefix,
            self.working_path,
            self.gpu_target,
            self.datatype,
            self.layout,
            self.config_json,
        )
        work_items = [
            (tile_config, trait_combo) + shared_args
            for tile_config in tile_configs
            for trait_combo in trait_combos
        ]
        total = len(work_items)

        print(
            f"Generating {total} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {total}")

        # Show first few work items for debugging
        if work_items:
            first_tile, first_traits = work_items[0][:2]
            print(" First work item example:")
            print(f" Tile config: {first_tile}")
            print(f" Trait combo: {first_traits[:3]}")  # Show first 3 traits

        kernel_list = []
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as pool:
            print(f" Submitting {total} tasks to executor...")
            pending = {}
            for item in work_items:
                pending[pool.submit(_generate_single_kernel_individual, item)] = item
            print(" All tasks submitted, waiting for completion...")

            # Collect results with progress reporting
            finished = concurrent.futures.as_completed(pending)
            for done_count, future in enumerate(finished, start=1):
                if done_count % 100 == 0 or done_count == total:
                    print(f" Progress: {done_count}/{total} kernels generated")
                try:
                    outcome = future.result()
                except Exception as exc:
                    print(f"Kernel generation failed for {pending[future]}: {exc}")
                else:
                    if outcome:
                        kernel_list.append(outcome)

        # Sort by kernel name for consistent ordering
        kernel_list.sort(key=lambda entry: entry[0])

        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)

        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker function: generate and write the header for one kernel.

    Args:
        work_item: 8-tuple ``(tile_config, trait_combo, kernel_name_prefix,
            working_path, gpu_target, datatype, layout, config_json)``.

    Returns:
        ``(kernel_name, trait_combo, tile_config)`` on success, ``None`` when
        generation or writing raised (the error is printed, not re-raised).
    """
    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json,
    ) = work_item

    # Create a temporary builder instance for this worker
    builder = GemmUniversalKernelBuilder(
        kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
    )

    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Drop the "<prefix>_" lead-in from the file name. The test and the
        # slice both use the same prefix string, so they cannot disagree.
        name_prefix = f"{kernel_name_prefix}_"
        if kernel_name.startswith(name_prefix):
            simplified_name = kernel_name[len(name_prefix) :]
        else:
            simplified_name = kernel_name

        # os.path.join accepts both str and path-like working_path values.
        header_file = os.path.join(
            working_path, f"gemm_universal_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)

        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        # Best effort: the caller skips kernels whose worker returns None.
        print(f"Error generating individual kernel: {e}")
        return None
def _parse_tile_config_arg(text):
    """Parse a ``--tile_config`` string into the tile-config dict.

    Expected shape: three ``_``-separated groups (tile, warp, warp tile), each
    made of three ``x``-separated integers, e.g. ``256x256x32_4x1x1_16x16x16``.

    Raises:
        ValueError: a group or dimension is missing, or a value is not an integer.
    """
    groups = text.split("_")
    if len(groups) < 3:
        raise ValueError(
            f"--tile_config must contain three '_'-separated groups, got {text!r}"
        )

    config = {}
    for group_name, group in zip(("tile", "warp", "warp_tile"), groups):
        dims = group.split("x")
        if len(dims) < 3:
            raise ValueError(
                f"--tile_config group {group!r} must contain three 'x'-separated integers"
            )
        for axis, value in zip("mnk", dims):
            config[f"{group_name}_{axis}"] = int(value)
    return config


def _parse_trait_combo_arg(text):
    """Parse a ``--trait_combo`` string into the 7-tuple trait combination.

    Field order: pipeline, epilogue, scheduler, pad_m, pad_n, pad_k,
    persistent. A flag is true only when its field is exactly ``"True"``.

    Raises:
        ValueError: fewer than seven ``_``-separated fields were given.
    """
    fields = text.split("_")
    if len(fields) < 7:
        raise ValueError(
            f"--trait_combo must contain seven '_'-separated fields, got {text!r}"
        )
    pipeline, epilogue, scheduler = fields[:3]
    pad_m, pad_n, pad_k, persistent = (flag == "True" for flag in fields[3:7])
    return (pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)


def main():
    """Command-line entry point of the GEMM Universal instance builder.

    Exactly one mode is acted on, checked in this order: ``--list_kernels``,
    ``--gen_single``, ``--gen_all_individual``. With none of them given the
    parser reports a usage error.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Universal kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr", "rrr", "ccr", "crr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # Defensive re-checks; the argparse `choices` above already restrict both
    # values to these sets.
    assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
        f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
    )

    layout_parts = args.layout.lower()
    assert len(layout_parts) == 3, (
        f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
    )
    assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
        f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
    )
    assert layout_parts[2] == "r", (
        f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
    )

    kernel_name_prefix = "gemm_universal"

    builder = GemmUniversalKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.config_json,
    )

    if args.list_kernels:
        builder._list_kernels()
    elif args.gen_single:
        # Generate a single kernel file input validation
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Malformed strings become a usage error instead of a bare
        # IndexError/ValueError traceback.
        try:
            tile_config = _parse_tile_config_arg(args.tile_config)
            trait_combo = _parse_trait_combo_arg(args.trait_combo)
        except ValueError as err:
            parser.error(str(err))

        # Generate the kernel
        # NOTE(review): the return value is discarded here, whereas
        # _generate_single_kernel_individual writes the returned code to a
        # header itself - confirm the base builder writes the file in this path.
        builder._generate_kernel_instance(
            tile_config,
            trait_combo,
        )
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder._generate_all_individual(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,950 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import logging
from typing import Tuple, List
# Pipeline name lists. NOTE(review): presumably one list per GEMM flavour
# (regular vs. preshuffle); nothing in this part of the file reads them.
GEMM_PIPELINES = ["mem", "compv3", "compv4"]
GEMM_PRESHUFFLE_PIPELINES = ["preshufflev2"]
# One-letter layout code -> fully qualified ck_tile layout type name.
LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
# Datatype label -> size of one element in bytes, as returned by element_size().
# int4 is half a byte, which is why the values are not all integers.
ELEMENT_SIZE_MAP = {
"fp16": 2,
"bf16": 2,
"int8": 1,
"fp8": 1,
"bf8": 1,
"int4": 0.5,
"int32": 4,
"fp32": 4,
"fp64": 8,
}
# GPU name -> allowed [warp_m, warp_n, warp_k] triples.
# Read by validate_warp_configuration(), which accepts any triple (with a
# warning) for a GPU name that has no entry here.
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx942": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx950": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx1201": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1],
],
}
# GPU name -> "<a>_<b>_<c>" datatype key -> allowed
# [warp_tile_m, warp_tile_n, warp_tile_k] triples.
# Read by validate_gemm_preshuffle_warp_tile_combination(); a GPU name or
# datatype key without an entry is accepted without checking.
GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = {
"gfx90a": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[64, 4, 16],
],
"fp8_fp8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"bf8_bf8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
},
}
# GPU name -> "<a>_<b>_<c>" datatype key -> allowed
# [warp_tile_m, warp_tile_n, warp_tile_k] triples.
# Read by validate_gemm_warp_tile_combination(); a GPU name or datatype key
# without an entry is accepted without checking.
GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = {
"gfx90a": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"bf8_bf8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
},
"gfx1201": { # Check how to handle for GEMM and Multi D
"fp16_fp16_fp16": [
[16, 16, 16],
],
},
}
# (pipeline, epilogue, scheduler) triples that is_trait_combination_valid()
# rejects; every entry pairs a compv3/compv4 pipeline with the interwave scheduler.
TRAIT_UNSUPPORTED_COMBINATIONS = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
}
def element_size(data_type: str) -> float:
    """Return the per-element size in bytes for a datatype label.

    The lookup is case-insensitive. Raises ValueError for a label that is not
    a key of ELEMENT_SIZE_MAP.
    """
    normalized = data_type.lower()
    size = ELEMENT_SIZE_MAP.get(normalized)
    if size is None:
        raise ValueError(f"Unsupported data type: {normalized}")
    return size
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Return False only for triples listed in TRAIT_UNSUPPORTED_COMBINATIONS."""
    candidate = (pipeline, epilogue, scheduler)
    return candidate not in TRAIT_UNSUPPORTED_COMBINATIONS
def validate_warp_configuration(
    warp_m: int,
    warp_n: int,
    warp_k: int,
    gpu_name: str,
) -> bool:
    """Check a (warp_m, warp_n, warp_k) triple against the per-GPU allow list.

    A GPU without an entry in WARP_SUPPORTED_COMBINATIONS is treated
    permissively: a warning is logged and True is returned.
    """
    supported = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {})
    if supported:
        return [warp_m, warp_n, warp_k] in supported

    # If GPU not recognized, try to be permissive but log warning
    logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}")
    return True
def validate_dimension_alignment(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
) -> Tuple[bool, List[str]]:
    """Check that each tile extent is a multiple of warp count times warp tile.

    Returns:
        ``(ok, issues)``: ``ok`` is True when ``issues`` is empty; ``issues``
        holds one message per misaligned axis, in m, n, k order.
    """
    per_axis = (
        ("m", tile_m, warp_m, warp_tile_m),
        ("n", tile_n, warp_n, warp_tile_n),
        ("k", tile_k, warp_k, warp_tile_k),
    )

    alignment_issues = []
    for axis, tile, warp, warp_tile in per_axis:
        remainder = tile % (warp * warp_tile)
        if remainder != 0:
            alignment_issues.append(
                f"tile_{axis}({tile}) % [{warp}x{warp_tile}] = {remainder}"
            )

    return not alignment_issues, alignment_issues
def validate_lds_capacity(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    a_datatype: str,
    b_datatype: str,
    pipeline: str,
) -> Tuple[bool, str]:
    """Check that the A and B tiles together fit the LDS budget.

    The budget is 2**15 bytes for the "preshufflev2" and "compv4" pipelines
    and 2**16 bytes for every other pipeline.

    Returns:
        ``(True, "")`` when the tiles fit, otherwise ``(False, message)`` with
        a breakdown of the required sizes.
    """
    a_bytes = (tile_m * tile_k) * element_size(a_datatype)
    b_bytes = (tile_n * tile_k) * element_size(b_datatype)
    required = a_bytes + b_bytes

    if pipeline in ["preshufflev2", "compv4"]:
        limit = 2**15
    else:
        limit = 2**16

    if required <= limit:
        return True, ""

    headline = (
        f"LDS capacity exceeded: Total required {required:,}B ({required / 1024:.1f}KB) > "
        f"maximum allowed {limit:,}B ({limit / 1024}KB). Breakdown:"
    )
    breakdown = [
        f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {a_bytes:,}B",
        f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {b_bytes:,}B",
    ]
    return False, "\n".join([headline, *breakdown])
def validate_gemm_warp_tile_combination(
warp_tile_m: int,
warp_tile_n: int,
warp_tile_k: int,
a_datatype: str,
b_datatype: str,
c_datatype: str,
gpu_name: str,
) -> Tuple[bool, str]:
"""Validate warp tile combination against GPU-specific supported combinations."""
# Construct the key for looking up supported combinations
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
# Check if we have GPU-specific combinations
gpu_warp_tile_combinations = GEMM_WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
if not gpu_warp_tile_combinations:
# If GPU not recognized, try to be permissive but log warning
logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
return True, ""
# Check if we have combinations for this data type combination
allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, [])
if not allowed_combinations:
# For data type combinations not in the list, be permissive
logging.debug(
f"No warp tile combinations found for data types: {warp_tile_key}"
)
return True, ""
# Check if current combination is in the allowed list
if current_combination not in allowed_combinations:
error_msg = (
f"Invalid warp tile combination: {current_combination} not in allowed list. "
f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
)
return False, error_msg
return True, ""
def validate_gemm_preshuffle_warp_tile_combination(
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    gpu_name: str,
) -> Tuple[bool, str]:
    """Check a preshuffle-GEMM warp tile shape against the per-GPU allow-list.

    Same procedure as the plain GEMM check, but the table consulted is
    ``GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS``: it is indexed first by
    ``gpu_name`` and then by the key ``"<a>_<b>_<c>"`` built from the datatypes.

    Returns:
        ``(True, "")`` when the shape ``[warp_tile_m, warp_tile_n, warp_tile_k]``
        is listed, or when no (non-empty) entry exists for the GPU or for the
        datatype key. ``(False, message)`` when an allow-list exists and the
        shape is missing from it.
    """
    dtype_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
    requested = [warp_tile_m, warp_tile_n, warp_tile_k]

    # Unknown GPU: accept, but make the gap visible at WARNING level.
    per_gpu_table = GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get(
        gpu_name, {}
    )
    if not per_gpu_table:
        logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
        return True, ""

    # Known GPU but no list for this datatype triple: accept quietly.
    supported = per_gpu_table.get(dtype_key, [])
    if not supported:
        logging.debug(
            f"No warp tile combinations found for data types: {dtype_key}"
        )
        return True, ""

    if requested in supported:
        return True, ""

    return (
        False,
        f"Invalid warp tile combination: {requested} not in allowed list. "
        f"Valid combinations for '{dtype_key}' on {gpu_name}: {supported}",
    )
def is_tile_config_valid(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    layout: str,
    gpu_target: str,
) -> bool:
    """Run every tile-configuration check and report overall validity.

    Checks are applied in this order, stopping at the first failure:

    1. all nine tile / warp / warp-tile dimensions are strictly positive;
    2. ``warp_* * warp_tile_*`` does not exceed ``tile_*`` in M, N and K;
    3. ``validate_warp_configuration``;
    4. ``validate_dimension_alignment``;
    5. ``validate_lds_capacity``;
    6. pipeline-specific checks: ``validate_gemm`` plus the GEMM warp-tile
       allow-list when ``pipeline`` is in ``GEMM_PIPELINES``, otherwise
       ``validate_gemm_preshuffle`` plus the preshuffle allow-list when it is
       in ``GEMM_PRESHUFFLE_PIPELINES``. A pipeline in neither collection
       skips this step.

    Failure reasons from steps 3-6 are logged at DEBUG level.

    Returns:
        True if no check failed, False otherwise.
    """
    shape_args = (
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
    )

    # Basic sanity: nothing may be zero or negative.
    if any(extent <= 0 for extent in shape_args):
        return False

    # The warps' combined footprint has to fit inside the block tile.
    if (
        warp_m * warp_tile_m > tile_m
        or warp_n * warp_tile_n > tile_n
        or warp_k * warp_tile_k > tile_k
    ):
        return False

    # Warp layout supported on this GPU target?
    if not validate_warp_configuration(warp_m, warp_n, warp_k, gpu_target):
        logging.debug(
            f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
        )
        return False

    # Divisibility of the block tile by warp x warp-tile.
    aligned, misalignments = validate_dimension_alignment(*shape_args)
    if not aligned:
        logging.debug(
            f"Dimension alignment failed: {', '.join(misalignments)}. "
            f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
            f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
        )
        return False

    # LDS footprint of the A and B tiles.
    lds_ok, lds_message = validate_lds_capacity(
        tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
    )
    if not lds_ok:
        logging.debug(f"LDS validation failed: {lds_message}")
        return False

    # Argument bundles shared by the two pipeline families below.
    pipeline_args = shape_args + (
        a_datatype,
        b_datatype,
        c_datatype,
        pipeline,
        layout,
        gpu_target,
    )
    warp_tile_args = (
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        a_datatype,
        b_datatype,
        c_datatype,
        gpu_target,
    )

    if pipeline in GEMM_PIPELINES:
        passed, reason = validate_gemm(*pipeline_args)
        if not passed:
            logging.debug(f"GEMM validation failed: {reason}")
            return False

        passed, reason = validate_gemm_warp_tile_combination(*warp_tile_args)
        if not passed:
            logging.debug(f"Warp tile validation failed: {reason}")
            return False
    elif pipeline in GEMM_PRESHUFFLE_PIPELINES:
        passed, reason = validate_gemm_preshuffle(*pipeline_args)
        if not passed:
            logging.debug(
                f"GEMM Preshuffle validation failed: {reason}"
            )
            return False

        passed, reason = validate_gemm_preshuffle_warp_tile_combination(
            *warp_tile_args
        )
        if not passed:
            logging.debug(f"Warp tile validation failed: {reason}")
            return False

    return True
# [TODO] Handle this while moving the code to commons. Add more datatypes to this function if needed.
def get_dtype_string(datatype: str) -> str:
    """Translate a short datatype name into the C++ type spelling.

    Known names are ``fp16``, ``fp8``, ``bf8``, ``bf16``, ``fp32`` and ``fp64``.
    Any other name falls back to ``"float"``.
    """
    cpp_type_by_name = dict(
        fp16="ck_tile::fp16_t",
        fp8="ck_tile::fp8_t",
        bf8="ck_tile::bf8_t",
        bf16="ck_tile::bf16_t",
        fp32="float",
        fp64="double",
    )
    if datatype in cpp_type_by_name:
        return cpp_type_by_name[datatype]
    return "float"
def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
    """Map a layout code such as 'rcr', 'ccr', 'crr' or 'rrr' to layout names.

    The code is stringified, stripped and lower-cased; its first three
    characters are then looked up in ``LAYOUT_MAP`` in order.

    Returns:
        ``(ALayout, BLayout, CLayout)``.
    """
    normalized = str(layout_code).strip().lower()
    # Positions 0, 1, 2 select the A, B and C layouts respectively.
    return tuple(LAYOUT_MAP[normalized[position]] for position in range(3))
def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]:
    """Map a 4-letter layout code such as 'rcrr', 'ccrr', 'crrr' or 'rrrr' to layout names.

    The code is stringified, stripped and lower-cased. Characters 0-2 select the
    A, B and C layouts through ``LAYOUT_MAP``; character 3 selects the layout
    used for both D entries.

    Returns:
        ``(ALayout, BLayout, CLayout, [D0Layout, D1Layout])`` where the two D
        layouts are always equal because both come from the fourth character.
    """
    normalized = str(layout_code).strip().lower()
    abc_layouts = [LAYOUT_MAP[normalized[position]] for position in range(3)]
    # The single fourth letter is shared by D0 and D1.
    d_layout = LAYOUT_MAP[normalized[3]]
    return abc_layouts[0], abc_layouts[1], abc_layouts[2], [d_layout, d_layout]
def validate_whole_wg_cover_configuration(
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    layout,
    a_datatype,
    b_datatype,
) -> Tuple[bool, str]:
    """Run wg_cover_core_validation for the A tile and then for the B tile.

    ``layout[0]`` selects the A-matrix branch and ``layout[1]`` the B-matrix
    branch; each is compared against "r" and "c". For every matrix a vector
    load size is obtained from ``get_global_vector_load_size`` and one or two
    (XPerTile, YPerTile) pairs are passed to ``wg_cover_core_validation``.

    Returns:
        ``(True, "")`` if every core validation passes, otherwise
        ``(False, message)`` with the message produced by the failing call.
    """
    # The warp size is fixed at 64 here; the block size follows from the warp grid.
    warp_size = 64
    NumWarps = warp_m * warp_n * warp_k
    BlockSize = NumWarps * warp_size
    # NOTE(review): these stay 0 if layout[0] is neither "r" nor "c"; the values
    # are then passed to wg_cover_core_validation unchanged — confirm callers
    # only ever supply "r"/"c".
    XPerTile = 0
    YPerTile = 0
    vector_load_size = 0
    # A matrix validation
    if layout[0] == "r":
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, a_datatype, tile_m, tile_k
        )
        # Only the (tile_k, tile_m) pair is checked for this branch (below).
        XPerTile = tile_k
        YPerTile = tile_m
    elif layout[0] == "c":
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, a_datatype, tile_m, tile_m
        )
        # Validate distribution
        XPerTile = tile_k
        YPerTile = tile_m
        wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
            XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
        )
        if not wg_cover_core_valid:
            logging.debug(
                f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
            )
            return False, wg_cover_core_error
        # Second pass for this branch uses the swapped (tile_m, tile_k) pair.
        XPerTile = tile_m
        YPerTile = tile_k
    # Check shared by both A branches, using whatever pair was set last.
    wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
        XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
    )
    if not wg_cover_core_valid:
        logging.debug(
            f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}"
        )
        return False, wg_cover_core_error
    # B matrix validation
    # NOTE(review): if layout[1] is neither "r" nor "c", the values left over
    # from the A-matrix section are re-validated below.
    if layout[1] == "r":
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, b_datatype, tile_n, tile_n
        )
        # Validate distribution
        XPerTile = tile_k
        YPerTile = tile_n
        wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
            XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
        )
        if not wg_cover_core_valid:
            logging.debug(
                f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
            )
            return False, wg_cover_core_error
        # Second pass for this branch uses the swapped (tile_n, tile_k) pair.
        XPerTile = tile_n
        YPerTile = tile_k
    elif layout[1] == "c":
        # Only the (tile_k, tile_n) pair is checked for this branch (below).
        XPerTile = tile_k
        YPerTile = tile_n
        vector_load_size = get_global_vector_load_size(
            BlockSize, tile_k, b_datatype, tile_n, tile_k
        )
    # Check shared by both B branches, using whatever pair was set last.
    wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation(
        XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
    )
    if not wg_cover_core_valid:
        logging.debug(
            f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
        )
        return False, wg_cover_core_error
    return True, ""
def wg_cover_core_validation(
    XPerTile: int,
    YPerTile: int,
    BlockSize: int,
    vector_load_size: int,
    warp_size: int,
) -> Tuple[bool, str]:
    """Check that an XPerTile x YPerTile tile can be split evenly across a warp.

    The check derives ``X1`` (the smaller of ``vector_load_size`` and the
    per-thread element count), ``X0 = XPerTile / X1`` and
    ``Y1 = warp_size // X0`` and requires ``X0 * Y1 == warp_size``.

    Args:
        XPerTile: Tile extent along X; must be a multiple of ``vector_load_size``.
        YPerTile: Tile extent along Y.
        BlockSize: Number of threads in the workgroup.
        vector_load_size: Elements per vector load.
        warp_size: Number of threads per warp.

    Returns:
        ``(True, "")`` when the tile passes, otherwise ``(False, reason)``.
        Non-positive inputs are reported as a failure rather than raising
        ``ZeroDivisionError`` from the modulo/division steps below.
    """
    # Every quantity below is used as a divisor (directly or via a product), so
    # reject zero/negative values up front with a readable reason.
    named_inputs = (
        ("XPerTile", XPerTile),
        ("YPerTile", YPerTile),
        ("BlockSize", BlockSize),
        ("vector_load_size", vector_load_size),
        ("warp_size", warp_size),
    )
    for label, value in named_inputs:
        if value <= 0:
            return False, f"{label}({value}) must be positive"

    if XPerTile % vector_load_size != 0:
        return False, "XPerTile is not divisible by vector_load_size"

    # True division is kept on purpose: results match the previous behaviour
    # exactly, including for inputs that do not divide evenly.
    num_warps = BlockSize / warp_size
    LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
    X1 = LargestVec if vector_load_size > LargestVec else vector_load_size
    X0 = XPerTile / X1
    Y1 = warp_size // X0
    if X0 * Y1 != warp_size:
        return False, "X0 * Y1 != warp_size"
    return True, ""
def get_global_vector_load_size(
    BlockSize: int,
    KPerBlock: int,
    DataType: str,
    MNPerBlock: int,
    XPerTile: int,
) -> int:
    """Pick the widest vector load (in elements) usable for a tile.

    Candidate loads of 32, 16, 8, 4 and 2 bytes are tried from widest to
    narrowest. A candidate of ``B`` bytes corresponds to
    ``PackedSize * B / element_size(DataType)`` elements and is accepted when
    both ``XPerTile`` and the per-thread element count
    (``MNPerBlock * KPerBlock / BlockSize``) are multiples of that width.
    The 32-byte candidate is only considered when ``PackedSize == 2``; the
    4- and 2-byte candidates additionally require the element size to be at
    least ``PackedSize * 4`` / ``PackedSize * 2``.

    Args:
        BlockSize: Number of threads in the workgroup.
        KPerBlock: Tile extent along K.
        DataType: Datatype name understood by ``element_size``.
        MNPerBlock: Tile extent along M or N.
        XPerTile: Extent of the dimension the vector load runs along.

    Returns:
        The selected width as an ``int``; ``PackedSize`` if no candidate fits.
    """
    elements_per_thread = MNPerBlock * KPerBlock / BlockSize
    # Hard-coded to 1, so the 32-byte candidate below is currently never taken.
    PackedSize = 1
    # Looked up once; assumes element_size() is a pure lookup — TODO confirm.
    elem_bytes = element_size(DataType)

    def divides_evenly(load_bytes: int) -> bool:
        # Width in elements for a load of `load_bytes` bytes.
        width = PackedSize * load_bytes / elem_bytes
        return XPerTile % width == 0 and elements_per_thread % width == 0

    if PackedSize == 2 and divides_evenly(32):
        # Cast added: true division yields a float, but this function is
        # declared to return int and every other branch already casts.
        return int(PackedSize * 32 / elem_bytes)
    if divides_evenly(16):
        return int(PackedSize * 16 / elem_bytes)
    if divides_evenly(8):
        return int(PackedSize * 8 / elem_bytes)
    if elem_bytes >= PackedSize * 4 and divides_evenly(4):
        return int(PackedSize * 4 / elem_bytes)
    if elem_bytes >= PackedSize * 2 and divides_evenly(2):
        return int(PackedSize * 2 / elem_bytes)
    return PackedSize
def validate_gemm(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    layout: str,
    gpu_target: str,
    trait_name: str = None,
) -> Tuple[bool, str]:
    """Run the GEMM-specific validation for a tile configuration.

    The only check performed is ``validate_whole_wg_cover_configuration`` on
    the tile sizes, warp grid, layout and the A/B datatypes. The remaining
    parameters (``warp_tile_*``, ``c_datatype``, ``pipeline``, ``gpu_target``,
    ``trait_name``) are accepted but not used by this function.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, message)`` carrying the
        message from the failed check. (The annotation previously said
        ``bool``; the function has always returned this 2-tuple.)
    """
    # Validate whole workgroup cover configuration
    whole_workgroup_cover_valid, whole_workgroup_cover_error = (
        validate_whole_wg_cover_configuration(
            tile_m,
            tile_n,
            tile_k,
            warp_m,
            warp_n,
            warp_k,
            layout,
            a_datatype,
            b_datatype,
        )
    )
    if not whole_workgroup_cover_valid:
        logging.debug(
            f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}"
        )
        return False, whole_workgroup_cover_error
    return True, ""
def validate_gemm_preshuffle(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    layout: str,
    gpu_target: str,
    trait_name: str = None,
) -> Tuple[bool, str]:
    """Run the preshuffle-GEMM-specific validation for a tile configuration.

    Two checks are applied in order:

    1. ``validate_vector_load_alignment`` with a wave size of 64 and a vector
       load size of 16;
    2. ``validate_m0_m1_m2_configuration`` with the same two constants.

    ``tile_n``, ``warp_tile_n``, ``b_datatype``, ``c_datatype``, ``pipeline``,
    ``layout``, ``gpu_target`` and ``trait_name`` are accepted but not used here.

    Returns:
        ``(True, "")`` on success. On failure ``(False, message)``: a fixed
        ``"vector load alignment error"`` for check 1 (the detailed reason is
        only logged at DEBUG level), or the detailed message for check 2.
        (The annotation previously said ``bool``; the function has always
        returned this 2-tuple.)
    """
    # Validate vector load alignment.
    # True division: this is a float even when it divides evenly.
    m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
    vector_valid, vector_error = validate_vector_load_alignment(
        warp_tile_m,
        warp_tile_k,
        a_datatype,
        m_iter_per_warp,
        wave_size=64,
        vector_load_size=16,
    )
    if not vector_valid:
        logging.debug(f"Vector load alignment failed: {vector_error}")
        return False, "vector load alignment error"

    # Validate M0, M1, M2 configuration for matrix A row-major layout
    m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
        tile_m,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        a_datatype,
        vector_load_size=16,
        warp_size=64,
    )
    if not m0_m1_m2_valid:
        logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
        return False, m0_m1_m2_error
    return True, ""
def validate_vector_load_alignment(
    wg_m: int,
    wg_k: int,
    a_datatype: str,
    m_iter_per_warp: int,
    wave_size: int,
    vector_load_size: int,
) -> Tuple[bool, str]:
    """Check that the computed access size is a multiple of the vector load size.

    The access size is
    ``wg_m * wg_k * element_size(a_datatype) * m_iter_per_warp / wave_size``.

    Returns:
        ``(True, "")`` when the access size modulo ``vector_load_size`` is zero,
        otherwise ``(False, message)``. Any exception raised during the
        computation is caught and reported as ``(False, message)`` as well.
    """
    try:
        elem_bytes = element_size(a_datatype)
        numerator = wg_m * wg_k * elem_bytes * m_iter_per_warp
        access = numerator / wave_size

        leftover = access % vector_load_size
        if leftover == 0:
            return True, ""

        # Misaligned: spell out the full computation for the caller.
        return (
            False,
            f"Vector load alignment violation: "
            f"({wg_m} * {wg_k} * {elem_bytes} * {m_iter_per_warp} / {wave_size}) "
            f"% {vector_load_size} = {leftover} != 0. "
            f"Access size: {access} bytes",
        )
    except Exception as e:
        return False, f"Error in vector load validation: {str(e)}"
def validate_m0_m1_m2_configuration(
    tile_m: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    a_datatype: str,
    vector_load_size: int = 16,
    warp_size: int = 64,
) -> Tuple[bool, str]:
    """Check that ``tile_m`` factors as ``M0 * M1 * M2`` for a row-major A matrix.

    Derived quantities, each with its own divisibility requirement:

    * ``K1 = vector_load_size / element_size(a_datatype)`` (must be integral)
    * ``K0 = tile_k // K1`` (``tile_k`` must be divisible by ``K1``)
    * ``M2 = warp_size // K0`` (``warp_size`` must be divisible by ``K0``)
    * ``M0 = (warp_m * warp_n * warp_k * warp_size) // warp_size``
    * ``M1 = tile_m // (M2 * M0)`` (``tile_m`` must be divisible by ``M2 * M0``)

    Returns:
        ``(True, "")`` if every requirement holds, otherwise
        ``(False, message)``. Exceptions raised along the way (including
        division by zero) are caught and reported as ``(False, message)``.
    """
    try:
        m_per_block = tile_m

        # K1: elements covered by one vector load.
        k1_raw = vector_load_size / element_size(a_datatype)
        if k1_raw != int(k1_raw):
            return (
                False,
                f"K1 = {k1_raw} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})",
            )
        k1 = int(k1_raw)

        # K0: vector loads needed along K.
        if tile_k % k1 != 0:
            return False, f"tile_k({tile_k}) must be divisible by K1({k1})"
        k0 = tile_k // k1

        # M2: per-warp extent along M.
        if warp_size % k0 != 0:
            return False, f"warp_size({warp_size}) must be divisible by K0({k0})"
        m2 = warp_size // k0

        # M0 is the block size divided by the warp size, i.e. the warp count.
        num_warps = warp_m * warp_n * warp_k
        block_size = num_warps * warp_size
        m0 = block_size // warp_size

        # M1: what remains of the block's M extent.
        m2_times_m0 = m2 * m0
        if m2_times_m0 == 0:
            return False, f"M2({m2}) * M0({m0}) cannot be zero"
        if m_per_block % m2_times_m0 != 0:
            return (
                False,
                f"MPerBlock({m_per_block}) must be divisible by M2({m2}) * M0({m0}) = {m2_times_m0}",
            )
        m1 = m_per_block // m2_times_m0

        # Final consistency check on the factorisation.
        reconstructed = m0 * m1 * m2
        if reconstructed == m_per_block:
            return True, ""

        return (
            False,
            f"Incorrect M0, M1, M2 configuration! "
            f"M0({m0}) * M1({m1}) * M2({m2}) = {reconstructed} != MPerBlock({m_per_block}). "
            f"Configuration: K0={k0}, K1={k1}, NumWarps={num_warps}, BlockSize={block_size}",
        )
    except ZeroDivisionError as e:
        return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
    except Exception as e:
        return False, f"Error in M0/M1/M2 validation: {str(e)}"