[TILE ENGINE] Restructure to Base class of GEMM (#3434)

This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-12-19 09:53:56 -06:00
committed by GitHub
parent 0fd2b2f045
commit e22622f0ec
41 changed files with 2246 additions and 3458 deletions

View File

@@ -5,4 +5,6 @@ include_directories(BEFORE
${CMAKE_CURRENT_LIST_DIR}/include
)
add_subdirectory(ops)
add_subdirectory(ops/gemm)
add_subdirectory(ops/gemm_streamk)

View File

@@ -1,7 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_subdirectory(gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)
add_subdirectory(gemm_streamk)

View File

@@ -1,105 +0,0 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Test script for tile engine GEMM benchmarks
# This script demonstrates how to run the new individual benchmark executables
#
# Usage: ./test_benchmark.sh [build_directory]
# With no argument, the script tries to auto-detect the build directory
# under a hard-coded search root (see NOTE below).
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Find the build directory
if [ -z "$1" ]; then
# Try to find build directory automatically
# NOTE(review): searches a hard-coded root for a directory literally named
# "test_gemm_fix" -- confirm this matches real build trees before relying on it.
BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1)
if [ -z "$BUILD_DIR" ]; then
echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}"
echo "Usage: $0 <build_directory>"
exit 1
fi
else
BUILD_DIR="$1"
fi
echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}"
# Check if bin directory exists
if [ ! -d "$BUILD_DIR/bin" ]; then
echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}"
exit 1
fi
# Find all benchmark executables
echo -e "${YELLOW}Finding benchmark executables...${NC}"
BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null)
if [ -z "$BENCHMARKS" ]; then
echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}"
echo "Please build some benchmarks first with:"
echo "  cd $BUILD_DIR"
echo "  make benchmark_gemm_<kernel_name>"
exit 1
fi
# Count benchmarks
NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l)
echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}"
# Test sizes
# Each executable is run once per square problem size MxNxK below.
SIZES=(512 1024 2048)
# Results file
# Timestamped so repeated runs never clobber earlier results.
RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv"
echo -e "${YELLOW}Running benchmarks...${NC}"
echo "Results will be saved to: $RESULTS_FILE"
# Run each benchmark
COUNTER=0
for BENCH in $BENCHMARKS; do
COUNTER=$((COUNTER + 1))
BENCH_NAME=$(basename "$BENCH")
echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}"
for SIZE in "${SIZES[@]}"; do
echo -e "  Testing size: ${SIZE}x${SIZE}x${SIZE}"
# Run with verification
# grep filters the benchmark's output down to the interesting lines.
"$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \
-csv_filename="$RESULTS_FILE" -csv_format=simple \
2>&1 | grep -E "(Time:|Performance:|Verification:|Error)"
# PIPESTATUS[0] is the benchmark's own exit code; checking $? here would
# report grep's status instead and mask benchmark failures.
if [ ${PIPESTATUS[0]} -ne 0 ]; then
echo -e "  ${RED}Benchmark failed!${NC}"
fi
done
done
echo -e "\n${GREEN}Benchmark testing complete!${NC}"
echo "Results saved to: $RESULTS_FILE"
# Show summary if CSV file exists
if [ -f "$RESULTS_FILE" ]; then
echo -e "\n${YELLOW}Summary of results:${NC}"
# tail -n +2 skips the CSV header row.
echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)"
# NOTE(review): counts lines containing "true"/"false" anywhere in the CSV;
# assumes the verification column is the only true/false field -- confirm.
echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")"
echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")"
fi
# Example of running a specific benchmark with different options
echo -e "\n${YELLOW}Example commands for manual testing:${NC}"
echo "# Basic run:"
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024"
echo ""
echo "# With CPU verification:"
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1"
echo ""
echo "# JSON output for parsing:"
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true"
echo ""
echo "# Performance testing with TFLOPS metric:"
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1"

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Test script to verify that the validation logic is working correctly.
"""
from validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
validate_warp_tile_combination,
)
def test_warp_tile_validation():
    """Exercise validate_warp_tile_combination() against known fp16 cases."""
    print("Testing warp tile combination validation...")
    # Get GPU name
    gpu_name = "gfx90a"

    # Each entry pairs a [warp_tile_m, warp_tile_n, warp_tile_k] triple with
    # the validity expected for fp16 on this GPU.
    cases = [
        ([4, 64, 8], False),       # Invalid - not in supported list
        ([4, 64, 16], True),       # Valid
        ([32, 32, 8], True),       # Valid
        ([16, 16, 16], True),      # Valid
        ([32, 32, 16], True),      # Valid
        ([16, 16, 32], True),      # Valid
        ([64, 4, 16], True),       # Valid
        ([128, 128, 128], False),  # Invalid - too large
    ]

    print("\nTesting fp16 warp tile combinations:")
    for (wt_m, wt_n, wt_k), expected in cases:
        valid, msg = validate_warp_tile_combination(
            wt_m, wt_n, wt_k, "fp16", "fp16", "fp16", gpu_name
        )
        outcome = "PASS" if valid == expected else "FAIL"
        print(f" [{wt_m}, {wt_n}, {wt_k}]: {valid} - {outcome}")
        # Surface the validator's rejection reason for failing combinations.
        if not valid and msg:
            print(f" Reason: {msg}")
def test_trait_combinations():
    """Exercise is_trait_combination_valid() over pipeline/epilogue/scheduler combos."""
    print("\n\nTesting trait combination validation...")

    # (pipeline, epilogue, scheduler, expected_valid)
    cases = [
        ("mem", "default", "intrawave", True),
        ("mem", "cshuffle", "intrawave", True),
        ("compv3", "default", "interwave", False),   # Invalid combination
        ("compv3", "cshuffle", "interwave", False),  # Invalid combination
        ("compv4", "default", "interwave", False),   # Invalid combination
        ("compv4", "cshuffle", "interwave", False),  # Invalid combination
        ("compv3", "default", "intrawave", True),
        ("compv4", "cshuffle", "intrawave", True),
    ]

    print("\nTesting trait combinations:")
    for pipeline, epilogue, scheduler, expected in cases:
        result = is_trait_combination_valid(pipeline, epilogue, scheduler)
        outcome = "PASS" if result == expected else "FAIL"
        print(f" {pipeline}-{epilogue}-{scheduler}: {result} - {outcome}")
def _check_fp16_mem_config(tile, warp, warp_tile):
    """Run is_tile_config_valid() for one fp16 / 'mem'-pipeline configuration.

    ``tile``, ``warp`` and ``warp_tile`` are (m, n, k) triples.  Returns the
    validator's boolean verdict.  Centralizes the 13-argument call that was
    previously duplicated inline in test_full_tile_config_validation().
    """
    tile_m, tile_n, tile_k = tile
    warp_m, warp_n, warp_k = warp
    warp_tile_m, warp_tile_n, warp_tile_k = warp_tile
    return is_tile_config_valid(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        "fp16",
        "fp16",
        "fp16",
        "mem",
    )


def test_full_tile_config_validation():
    """Test full tile configuration validation end-to-end."""
    print("\n\nTesting full tile configuration validation...")
    # Test case that was failing in the build: warp tile [4, 64, 8] is not a
    # supported fp16 combination, so the validator must reject it.
    tile = (256, 256, 32)
    warp = (1, 4, 1)
    warp_tile = (4, 64, 8)
    print("\nTesting problematic configuration:")
    print(f" Tile: {tile[0]}x{tile[1]}x{tile[2]}")
    print(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}")
    print(f" WarpTile: {warp_tile[0]}x{warp_tile[1]}x{warp_tile[2]}")
    valid = _check_fp16_mem_config(tile, warp, warp_tile)
    print(f" Valid: {valid}")
    print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)")
    # Same configuration with warp_tile_k changed to a valid value.
    warp_tile = (4, 64, 16)
    print("\nTesting corrected configuration:")
    print(f" WarpTile: {warp_tile[0]}x{warp_tile[1]}x{warp_tile[2]}")
    valid = _check_fp16_mem_config(tile, warp, warp_tile)
    print(f" Valid: {valid}")
    print(" Expected: True")
def main():
    """Run every validation test in order, framed by banner lines."""
    banner = "=" * 60
    print(banner)
    print("GEMM Validation Test Suite")
    print(banner)
    # Order matters only for output readability; the tests are independent.
    for test_fn in (
        test_warp_tile_validation,
        test_trait_combinations,
        test_full_tile_config_validation,
    ):
        test_fn()
    print("\n" + banner)
    print("Test suite completed")
    print(banner)


if __name__ == "__main__":
    main()

View File

@@ -1,310 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# User-configurable knobs for the tile-engine GEMM benchmark build.
# Datatypes and layouts are semicolon-separated CMake lists; every
# (datatype, layout) pair gets its own family of benchmark targets below.
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
# Store the directory path for use in functions
# (captured once so the functions below can resolve generator-script and
# source paths regardless of the directory they are invoked from)
set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# Function to create individual GEMM targets
#
# create_individual_gemm_target(datatype layout trait tile_config config_json)
#   datatype    - element type tag, e.g. "fp16"
#   layout      - matrix layout tag, e.g. "rcr"
#   trait       - underscore-joined trait combo; the first three components
#                 are pipeline, epilogue and scheduler (see parsing below)
#   tile_config - "<tile>_<warp>_<warp_tile>" dims, e.g. "256x256x32_4x1x1_16x16x16"
#   config_json - JSON config file consumed by gemm_instance_builder.py
#
# Defines one EXCLUDE_FROM_ALL benchmark executable for a single kernel
# instance, a custom command generating its instance header at build time,
# and wires the target into the various collection targets.
function(create_individual_gemm_target datatype layout trait tile_config config_json)
# Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
return()
endif()
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
# First split by underscore to get three groups
string(REPLACE "_" ";" config_groups ${tile_config})
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
# Parse tile dimensions
string(REPLACE "x" ";" tile_parts ${tile_dims})
list(GET tile_parts 0 tile_m)
list(GET tile_parts 1 tile_n)
list(GET tile_parts 2 tile_k)
# Parse warp dimensions
string(REPLACE "x" ";" warp_parts ${warp_dims})
list(GET warp_parts 0 warp_m)
list(GET warp_parts 1 warp_n)
list(GET warp_parts 2 warp_k)
# Parse warp tile dimensions
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
list(GET warp_tile_parts 0 warp_tile_m)
list(GET warp_tile_parts 1 warp_tile_n)
list(GET warp_tile_parts 2 warp_tile_k)
set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Add custom command to generate the header file at build time
# (re-runs whenever the generator script or the JSON config changes)
add_custom_command(
OUTPUT ${instance_header}
COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${config_json}
--gen_single
--kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
--gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}"
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
# Create the executable
add_executable(${target_name}
# to save build time, exclude the target from "all" target of "gemm" directory and its ancestors
EXCLUDE_FROM_ALL
${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp
${instance_header}
)
# Set GPU architectures
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})
# Set compile definitions
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
)
# Include directories
target_include_directories(${target_name} PRIVATE
${GEMM_SOURCE_DIR}
${working_path}
)
# Compile options
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
-include ${instance_header}
)
# Add to collection targets
# (these custom targets are created at configure time in the main logic
# below, before this function is invoked)
add_dependencies(benchmark_gemm_all ${target_name})
add_dependencies(benchmark_gemm_${datatype} ${target_name})
add_dependencies(benchmark_gemm_${layout} ${target_name})
add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name})
# Add to trait-specific targets
# The first three underscore-separated components of the trait string are
# pipeline, epilogue and scheduler; trailing flag components are ignored here.
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 pipeline)
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name})
add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name})
add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name})
endfunction()
# Function to build individual GEMM targets
#
# build_individual_gemm_targets(datatype layout)
#   Resolves the JSON config for this (datatype, layout) pair, asks the
#   generator script for the list of kernel configurations, and creates one
#   individual benchmark target per configuration via
#   create_individual_gemm_target().
function(build_individual_gemm_targets datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Choose config file
# Priority order:
#   1. Environment variable GEMM_CONFIG_FILE
#   2. CMake variable GEMM_CONFIG_FILE
#   3. Default based on layout
# Check environment variable first
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
message(VERBOSE "  Using config from environment variable: ${config_filename}")
elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
# Use CMake variable if set
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
message(VERBOSE "  Using custom config: ${GEMM_CONFIG_FILE}")
else()
# Use default config for all layouts
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
message(VERBOSE "  Using default config for layout ${layout}")
endif()
# Check if config file exists
if(NOT EXISTS ${json_blob})
message(FATAL_ERROR "Config file not found: ${json_blob}")
endif()
# Determine number of workers for parallel generation
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
else()
# Use processor count but limit to avoid memory issues
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
math(EXPR num_workers "${num_cores}")
# Cap at 8 workers to keep peak memory during generation bounded.
if(num_workers GREATER 8)
set(num_workers 8)
endif()
endif()
# Generate individual kernel files using parallel version
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
message(VERBOSE "  Working path: ${working_path}")
message(VERBOSE "  Config file: ${json_blob}")
message(VERBOSE "  Python executable: ${Python3_EXECUTABLE}")
message(VERBOSE "  Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# First, just list the kernels (fast operation)
# This runs at configure time and writes gemm_kernel_count.txt and
# gemm_kernel_list.txt into the working path.
message(VERBOSE "  Listing kernel configurations...")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
# Read kernel count
if(EXISTS ${working_path}/gemm_kernel_count.txt)
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
string(STRIP "${kernel_count}" kernel_count)
message(VERBOSE "  Found ${kernel_count} kernel configurations")
else()
message(FATAL_ERROR "Kernel count file not found")
endif()
# Read kernel list and create targets
if(EXISTS ${working_path}/gemm_kernel_list.txt)
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
foreach(line IN LISTS kernel_lines)
# Parse line: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Create individual target
create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
endforeach()
else()
message(FATAL_ERROR "Kernel list file not found")
endif()
endfunction()
# Main build logic - Only individual builds supported
#
# Filters the requested GPU targets, creates the collection targets that
# individual benchmark executables attach themselves to, then generates the
# per-(datatype, layout) targets.  Collection targets MUST exist before
# create_individual_gemm_target() runs, since it calls add_dependencies()
# on them.
message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===")
message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}")
message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
message(VERBOSE "  Adding GPU target: ${target}")
endif()
endforeach()
# Skip build if no matching targets found
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
# NOTE(review): pools are declared here but targets must opt in via
# JOB_POOL_COMPILE to actually use them -- confirm intended usage.
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4    # Limit heavy compilations to prevent OOM
compile_normal=16  # Allow more parallel normal compilations
)
# Enable compiler cache if available and explicitly requested
# Disabled by default due to permission issues in CI environments
if(ENABLE_CCACHE_GEMM)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
message(VERBOSE "Using ccache for faster compilation")
else()
message(WARNING "ccache requested but not found")
endif()
else()
message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
endif()
# Create master collection targets
add_custom_target(benchmark_gemm_all)
# Create datatype collection targets
foreach(dt IN LISTS GEMM_DATATYPE)
add_custom_target(benchmark_gemm_${dt})
endforeach()
# Create layout collection targets
foreach(l IN LISTS GEMM_LAYOUT)
add_custom_target(benchmark_gemm_${l})
endforeach()
# Create combined collection targets
foreach(dt IN LISTS GEMM_DATATYPE)
foreach(l IN LISTS GEMM_LAYOUT)
add_custom_target(benchmark_gemm_${dt}_${l})
endforeach()
endforeach()
# Create trait-based collection targets
# These are common trait components used across all GEMM kernels
set(GEMM_PIPELINES "mem;compv3;compv4")
set(GEMM_EPILOGUES "default;cshuffle")
set(GEMM_SCHEDULERS "intrawave;interwave")
foreach(pipeline IN LISTS GEMM_PIPELINES)
add_custom_target(benchmark_gemm_${pipeline}_pipeline)
endforeach()
foreach(epilogue IN LISTS GEMM_EPILOGUES)
add_custom_target(benchmark_gemm_${epilogue}_epilogue)
endforeach()
foreach(scheduler IN LISTS GEMM_SCHEDULERS)
add_custom_target(benchmark_gemm_${scheduler}_scheduler)
endforeach()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_DATATYPE)
foreach(l IN LISTS GEMM_LAYOUT)
build_individual_gemm_targets(${dt} ${l})
endforeach()
endforeach()
endif()
add_subdirectory(gemm_universal)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)

View File

@@ -1,442 +0,0 @@
# CK Tile Engine GEMM Operations
## Overview
The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities.
## Table of Contents
1. [Build System Architecture](#build-system-architecture)
2. [Build Instructions](#build-instructions)
3. [Running Benchmarks](#running-benchmarks)
4. [Configuration System](#configuration-system)
5. [Scripts and Tools](#scripts-and-tools)
6. [Command Line Options](#command-line-options)
7. [Understanding Kernel Names](#understanding-kernel-names)
8. [Troubleshooting](#troubleshooting)
9. [Performance Tips](#performance-tips)
## Build System Architecture
### Individual Kernel Compilation (New Approach)
The new tile engine benchmark system compiles each kernel configuration into a separate executable. This provides:
- Better build parallelism
- Faster incremental builds
- More targeted testing
- Easier debugging of specific configurations
Each benchmark executable follows the naming pattern:
```
benchmark_gemm_<dtype>_<layout>_<config>_<tile_sizes>
```
### Monolithic Build (Legacy Approach)
The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments.
## Build Instructions
### Prerequisites
- ROCm installation
- CMake 3.16 or higher
- C++17 compatible compiler
### Basic Build
```bash
# In the root of composable kernel, create build directory
mkdir build && cd build
# Configure with specific datatypes and layouts
# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942)
# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64)
# Replace [Layout1;Layout2;...] with layouts (rcr, rrr, crr, ccr)
../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]"
# Build specific benchmarks
make benchmark_gemm_[Datatype1]_[Layout1] -j
```
### Configuration Options
The build system supports several configuration options:
#### Using Custom Config Files
```bash
# Method 1: CMake variable (config file must be in configs/ directory)
cmake -DGEMM_CONFIG_FILE=my_custom_config.json ...
# Method 2: Environment variable (takes precedence over CMake variable)
export GEMM_CONFIG_FILE=my_custom_config.json
cmake ...
```
#### Config File Priority Order
1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority)
2. **CMake variable** `GEMM_CONFIG_FILE`
3. **Default config** (default_config.json for all layouts)
**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory.
### Example Build Commands
```bash
# Build for gfx942 with fp8 and fp16 datatypes, rcr layout
mkdir build && cd build
../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr"
make benchmark_gemm_fp8_rcr -j
make benchmark_gemm_fp16_rcr -j
```
### Building Individual Kernels
```bash
# Build a specific kernel configuration
make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32
# Build all fp16 benchmarks in parallel
make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}')
```
### Rebuilding After Configuration Changes
If you modify the configuration file, you must rebuild:
```bash
rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j
```
## Running Benchmarks
### Individual Kernel Execution
```bash
cd /path/to/build/directory
./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \
-m=512 -n=512 -k=512 -verify=1
```
### Monolithic Executable (Legacy)
```bash
# Run specific pipeline/scheduler/epilogue combination
./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default
```
### Automated Testing
Use the provided test script to run multiple benchmarks:
```bash
cd /path/to/composable_kernel/tile_engine/ops/gemm
./test_benchmark.sh [build_directory]
```
## Configuration System
### Configuration Files
The system uses JSON configuration files to specify kernel parameters:
- `configs/default_config.json` - Default configurations for various datatypes
- `configs/user_provided_config.json` - User-customizable configurations
### Configuration Structure
```json
{
"tile_config": {
"tile_m": {"values": [256, 128]},
"tile_n": {"values": [256, 128]},
"tile_k": {"values": [64, 32]},
"warp_m": {"values": [2, 4]},
"warp_n": {"values": [2, 1]},
"warp_k": {"values": [1]},
"warp_tile_m": {"values": [32, 16]},
"warp_tile_n": {"values": [32, 16]},
"warp_tile_k": {"values": [16, 32]}
},
"trait_config": {
"pipeline": {"values": ["compv3", "compv4", "mem"]},
"scheduler": {"values": ["intrawave", "interwave"]},
"epilogue": {"values": ["default", "cshuffle"]},
"pad_m": {"values": [false]},
"pad_n": {"values": [false]},
"pad_k": {"values": [false]},
"persistent": {"values": [false]}
}
}
```
## Scripts and Tools
### Python Scripts
#### gemm_instance_builder.py
**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files.
**Key Features**:
- Generates individual kernel header files for separate compilation
- Supports multiple data types (fp16, fp8, bf16, fp32, fp64)
- Validates tile configurations for correctness
- Creates CMake integration files
**Usage**:
```bash
python gemm_instance_builder.py \
--working_path ./generated \
--datatype fp16 \
--layout rcr \
--config_json configs/user_provided_config.json \
--gen_all_individual
```
#### gemm_instance_builder_parallel.py
**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations.
**Features**:
- Multi-threaded kernel generation
- Improved performance for large configuration spaces
#### validation_utils.py
**Purpose**: Provides comprehensive validation functions for kernel configurations.
**Key Functions**:
- `is_tile_config_valid()` - Validates tile dimensions and alignments
- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported
- `validate_warp_tile_combination()` - GPU-specific warp tile validation
- `validate_lds_capacity()` - Ensures configurations fit in LDS memory
**Validation Checks**:
- Dimension alignment (tile dimensions must be divisible by warp dimensions)
- LDS capacity constraints
- GPU-specific warp tile support
- Unsupported trait combinations
#### test_validation.py
**Purpose**: Test suite for the validation logic to ensure correctness.
**Usage**:
```bash
python test_validation.py
```
**Tests**:
- Warp tile combination validation
- Trait combination validation
- Full tile configuration validation
#### gemm_benchmark.py
**Purpose**: Python script for running and analyzing GEMM benchmarks.
**Features**:
- Automated benchmark execution
- Performance data collection
- Result analysis and reporting
#### json_config.py
**Purpose**: Configuration file parsing and management.
**Features**:
- JSON configuration loading
- Default configuration handling
- Configuration validation
#### codegen_utils.py
**Purpose**: Utility functions for code generation.
**Features**:
- Template processing
- Code formatting utilities
- File generation helpers
### Shell Scripts
#### test_benchmark.sh
**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables.
**Features**:
- Automatic build directory detection
- Batch execution of multiple benchmarks
- CSV result collection
- Colored output for easy reading
- Example command generation
**Usage**:
```bash
# Auto-detect build directory
./test_benchmark.sh
# Specify build directory
./test_benchmark.sh /path/to/build/directory
```
**What it does**:
1. Finds all benchmark executables in the build directory
2. Runs each with multiple problem sizes (512, 1024, 2048)
3. Performs GPU verification
4. Saves results to timestamped CSV file
5. Provides summary statistics
## Command Line Options
All benchmark executables support the following options:
### Matrix Dimensions
- `-m=<value>` - M dimension (default: 3840)
- `-n=<value>` - N dimension (default: 4096)
- `-k=<value>` - K dimension (default: 2048)
### Strides
- `-stride_a=<value>` - Stride for matrix A (default: 0, auto-calculated)
- `-stride_b=<value>` - Stride for matrix B (default: 0, auto-calculated)
- `-stride_c=<value>` - Stride for matrix C (default: 0, auto-calculated)
### Verification
- `-verify=<0|1|2>` - Verification mode
- 0: No verification (default)
- 1: CPU verification
- 2: GPU verification
### Performance Testing
- `-warmup=<value>` - Warmup iterations (default: 50)
- `-repeat=<value>` - Benchmark iterations (default: 100)
- `-timer=<true|false>` - Use GPU timer (default: true)
- `-flush_cache=<true|false>` - Flush cache between runs (default: true)
- `-rotating_count=<value>` - Cache rotation count (default: 1000)
### Initialization
- `-init=<0|1|2>` - Tensor initialization method
- 0: Random values [-1, 1] (default)
- 1: Linear sequence (i % 17)
- 2: Constant value (1.0)
### Output Options
- `-log=<true|false>` - Enable verbose logging (default: false)
- `-metric=<0|1|2>` - Performance metric
- 0: Latency in ms (default)
- 1: TFLOPS
- 2: Bandwidth in GB/s
- `-json_output=<true|false>` - JSON format output (default: false)
- `-csv_filename=<filename>` - Save results to CSV
- `-csv_format=<simple|comprehensive>` - CSV format (default: comprehensive)
### Advanced Options
- `-split_k=<value>` - Split-K factor (default: 1)
- `-structured_sparsity=<true|false>` - Enable structured sparsity (default: false)
- `-pipeline=<compv3|compv4|mem>` - Pipeline type (default: compv3)
- `-scheduler=<intrawave|interwave>` - Scheduler type (default: intrawave)
- `-epilogue=<cshuffle|default>` - Epilogue type (default: cshuffle)
- `-pad_m=<true|false>` - Pad M dimension (default: false)
- `-pad_n=<true|false>` - Pad N dimension (default: false)
- `-pad_k=<true|false>` - Pad K dimension (default: false)
- `-persistent=<true|false>` - Use persistent kernel (default: false)
## Understanding Kernel Names
The kernel naming convention encodes the configuration:
```
benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16
^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^ ^^^^^^^ ^^^^^^^^^
| | | | | | | | |
| | | | | Padding & flags | | Warp tile
| | | | Scheduler | Thread tile
| | | Epilogue Block tile
| | Pipeline
| Layout (Row-Column-Row)
Data type
```
### Components:
- **Data type**: fp16, fp32, bf16, fp8, bf8, int8
- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr
- **Pipeline**: mem, compv3, compv4
- **Epilogue**: default, cshuffle
- **Scheduler**: intrawave, interwave
- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags)
- **Tile sizes**: BlockTile_WarpPartition_WarpTile (e.g. `256x128x32` block tile, `4x1x1` per-block warp partition, `32x32x16` warp tile)
## Troubleshooting
### Common Issues
1. **Kernel not found**
- Ensure the specific benchmark executable is built
- Check the build directory bin/ folder
2. **Verification failures**
- Try GPU verification (-verify=2) which may be more accurate
- Check data type compatibility
- Verify stride calculations
3. **Build failures**
- Check GPU architecture compatibility
- Ensure ROCm is properly installed
- Verify configuration file syntax
4. **Performance variations**
- Increase warmup iterations
- Disable CPU frequency scaling
- Use GPU timer for accurate measurements
### Debug Options
Enable verbose logging:
```bash
./bin/benchmark_gemm_... -log=true -verify=1
```
Test validation logic:
```bash
python test_validation.py
```
## Performance Tips
1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions
2. **Warmup**: Use at least 50-100 warmup iterations
3. **GPU Timer**: Always use `-timer=true` for accurate measurements
4. **Cache Management**: Enable cache flushing for consistent results
5. **Thread Affinity**: Set CPU affinity to reduce variation
## Integration Examples
### Python Integration
```python
import subprocess
import json
# Run benchmark with JSON output
result = subprocess.run([
'./bin/benchmark_gemm_fp16_rcr_...',
'-m=1024', '-n=1024', '-k=1024',
'-json_output=true'
], capture_output=True, text=True)
# Parse results
data = json.loads(result.stdout)
print(f"Performance: {data['tflops']} TFLOPS")
```
### Batch Testing Script
```bash
#!/bin/bash
SIZES="512 1024 2048 4096"
for size in $SIZES; do
echo "Testing ${size}x${size}x${size}"
./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \
-verify=2 -csv_filename=results.csv
done
```
## Contributing
When adding new features or configurations:
1. Update validation logic in `validation_utils.py`
2. Add tests to `test_validation.py`
3. Update configuration examples
4. Document new command-line options
For more information about the Composable Kernel project, visit the main repository documentation.

View File

@@ -1,41 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher.
// Default member initializers replace the hand-written constructor; the
// defaults are identical to the original (compv3 / intrawave / cshuffle,
// all padding and persistence disabled).
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m      = false; // pad the M dimension
    bool pad_n      = false; // pad the N dimension
    bool pad_k      = false; // pad the K dimension
    bool persistent = false; // persistent kernel launch
};

File diff suppressed because it is too large Load Diff

View File

@@ -70,7 +70,6 @@ function(create_individual_gemm_multi_d_target datatype layout trait tile_config
# Create the executable
add_executable(${target_name}
# to save build time, exclude the target from "all" target of "gemm_multi_d" directory and its ancestors
EXCLUDE_FROM_ALL
${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp
${instance_header}

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
std::string dtype_d0 = ck_tile::DataTypeTraits<D0DataType>::name;
std::string dtype_d1 = ck_tile::DataTypeTraits<D1DataType>::name;
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
std::string dtype_d0 = DataTypeTraits<D0DataType>::name;
std::string dtype_d1 = DataTypeTraits<D1DataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types.
// Maps each supported element type to the short string name used in
// benchmark reporting and kernel naming (e.g. "fp16", "bf8").
// The primary template is intentionally left undefined so that using an
// unsupported type is a compile-time error.
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
    static constexpr const char* name = "int32";
};
// Packed int4 (two 4-bit values per byte); name kept verbatim for reporting.
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher.
// Default member initializers replace the hand-written constructor; the
// defaults match the original exactly.
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m      = false; // pad the M dimension
    bool pad_n      = false; // pad the N dimension
    bool pad_k      = false; // pad the K dimension
    bool persistent = false; // persistent kernel launch
};

View File

@@ -0,0 +1,330 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Dynamically load GemmKernelBuilder from the parent directory's
    gemm_instance_builder.py and return the class object."""
    here = os.path.dirname(os.path.abspath(__file__))
    builder_path = os.path.join(os.path.dirname(here), "gemm_instance_builder.py")
    spec = importlib.util.spec_from_file_location("gemm_instance_builder", builder_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.GemmKernelBuilder


GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmMultiDKernelBuilder(GemmKernelBuilder):
    """Kernel-instance builder for multi-D GEMM.

    Extends the base GemmKernelBuilder with the elementwise function that is
    applied to the extra D operands (e.g. MultiDMultiply / MultiDAdd /
    PassThrough, as mapped in main()).
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json=None,
    ):
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )
        # C++ functor name applied to the D operands.
        self.elementwise_function = elementwise_function

    def _generate_all_individual(self, num_workers=None):
        """Generate individual kernel files for separate compilation with parallel processing"""
        if num_workers is None:
            num_workers = min(
                multiprocessing.cpu_count(), 8
            )  # Limit to avoid memory issues
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        # Prepare work items for parallel processing.
        # Each tuple carries everything a worker process needs to rebuild the
        # builder on its side (see _generate_single_kernel_individual).
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.kernel_name_prefix,
                        self.working_path,
                        self.gpu_target,
                        self.datatype,
                        self.layout,
                        self.elementwise_function,
                        self.config_json,
                    )
                )
        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {len(work_items)}")
        # Show first few work items for debugging
        if work_items:
            print(" First work item example:")
            tile_config, trait_combo = work_items[0][:2]
            print(f" Tile config: {tile_config}")
            print(f" Trait combo: {trait_combo[:3]}")  # Show first 3 traits
        # Process work items in parallel
        kernel_list = []
        completed = 0
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            # Submit all work items
            print(f" Submitting {len(work_items)} tasks to executor...")
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            print(" All tasks submitted, waiting for completion...")
            # Collect results with progress reporting (every 100 completions
            # plus a final line).
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 100 == 0 or completed == len(work_items):
                    print(
                        f" Progress: {completed}/{len(work_items)} kernels generated"
                    )
                try:
                    # Workers return None on failure; only keep real results.
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")
        # Sort kernel list for consistent ordering
        kernel_list.sort(key=lambda x: x[0])  # Sort by kernel name
        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker function to generate a single individual kernel file"""
    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    ) = work_item
    # Each worker process builds its own builder instance from the plain
    # values packed into the work item.
    worker = GemmMultiDKernelBuilder(
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    )
    try:
        kernel_name, instance_code = worker._generate_kernel_instance(
            tile_config, trait_combo
        )
        # Drop the "gemm_multi_d_" prefix so generated file names stay short.
        stem = kernel_name
        if stem.startswith("gemm_multi_d_"):
            stem = stem[len(kernel_name_prefix) + 1 :]
        target = working_path / f"gemm_multi_d_single_{stem}.hpp"
        with open(target, "w") as out:
            out.write(instance_code)
        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """CLI entry point for the GEMM multi-D kernel instance builder.

    Three mutually exclusive modes:
      --list_kernels        list kernel configurations without generating files
      --gen_all_individual  generate every kernel header (in parallel)
      --gen_single          generate one kernel; requires --kernel_name,
                            --tile_config and --trait_combo
    """
    parser = argparse.ArgumentParser(
        description="GEMM Multi D kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcrr", "rrrr", "ccrr", "crrr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--elementwise_function",
        required=True,
        help="Specify what element wise function for D, e.g. mul, add, passthrough",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()

    # argparse `choices` already constrains datatype/layout; the asserts give
    # clearer messages if this function is ever driven programmatically.
    assert args.datatype in ["fp16"], (
        f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
    )
    layout_parts = args.layout.lower()
    assert len(layout_parts) == 4, (
        f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
    )
    assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
        f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
    )
    assert layout_parts[2] == "r" and layout_parts[3] == "r", (
        f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
    )

    # Elementwise function name validation and mapping onto the C++ functor
    # names used in the generated kernels.
    elementwise_function = args.elementwise_function.lower()
    function_names = {
        "mul": "MultiDMultiply",
        "add": "MultiDAdd",
        "passthrough": "PassThrough",
    }
    if elementwise_function not in function_names:
        raise ValueError(
            f"Invalid elementwise function: {elementwise_function}. "
            f"Valid options are: {', '.join(function_names)}"
        )
    args.elementwise_function = function_names[elementwise_function]

    # Create builder
    kernel_name_prefix = "gemm_multi_d"
    builder = GemmMultiDKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.elementwise_function,
        args.config_json,
    )
    if args.list_kernels:
        builder._list_kernels()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config of the form "MxNxK_WMxWNxWK_WTMxWTNxWTK"
        tile_parts = args.tile_config.split("_")
        tile_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }
        # Parse trait combo. The padding/persistent flags arrive as the
        # strings "True"/"False": convert them to real booleans, matching the
        # preshuffle builder. The raw strings are always truthy ("False" is a
        # non-empty string), so passing them through unconverted would enable
        # padding/persistence unconditionally.
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3] == "True",  # pad_m
            trait_parts[4] == "True",  # pad_n
            trait_parts[5] == "True",  # pad_k
            trait_parts[6] == "True",  # persistent
        )
        # Generate the kernel
        builder._generate_kernel_instance(
            tile_config,
            trait_combo,
        )
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder._generate_all_individual(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
if __name__ == "__main__":
main()

View File

@@ -67,7 +67,6 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
# Create the executable
add_executable(${target_name}
# to save build time, exclude the target from "all" target of "gemm_preshuffle" directory and its ancestors
EXCLUDE_FROM_ALL
${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp
${instance_header}

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -11,12 +11,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "gemm_preshuffle_profiler.hpp"
#include "gemm_preshuffle_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -0,0 +1,181 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types.
// Maps each supported element type to the short string name used in
// benchmark reporting and kernel naming (e.g. "fp16", "bf8").
// The primary template is intentionally left undefined so that using an
// unsupported type is a compile-time error.
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
    static constexpr const char* name = "int32";
};
// Packed int4 (two 4-bit values per byte); name kept verbatim for reporting.
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher (preshuffle variant).
// Default member initializers replace the hand-written constructor; the
// defaults match the original exactly.
struct KernelTraits
{
    std::string pipeline  = "preshufflev2"; // preshufflev2
    std::string scheduler = "default";      // intrawave, interwave, default
    std::string epilogue  = "default";      // cshuffle, default
    bool pad_m      = false; // pad the M dimension
    bool pad_n      = false; // pad the N dimension
    bool pad_k      = false; // pad the K dimension
    bool persistent = false; // persistent kernel launch
};
// Helper to extract traits from a kernel name by substring matching.
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
{
    KernelTraits traits;
    const auto contains = [&kernel_name](const char* token) {
        return kernel_name.find(token) != std::string::npos;
    };
    // Pipeline: only preshufflev2 is recognized; otherwise the default stays.
    if(contains("preshufflev2"))
    {
        traits.pipeline = "preshufflev2";
    }
    // Scheduler: "interwave" is checked before "intrawave"; neither present
    // means "default".
    traits.scheduler = contains("interwave")   ? "interwave"
                       : contains("intrawave") ? "intrawave"
                                               : "default";
    // Epilogue: a bare "default" token (but not the "default_" form) selects
    // the default epilogue; anything else maps to cshuffle.
    traits.epilogue =
        (contains("default") && !contains("default_")) ? "default" : "cshuffle";
    // Padding flags would need to be extracted from the kernel configuration
    // For now, we'll leave them as false
    return traits;
}
// Preshuffle the B operand on the host: the 2-D tensor t (lengths[0] = K,
// lengths[1] = N per the indexing below) is viewed as
// (N/N_Warp_Tile, N_Warp_Tile, K/K_Warp_Tile, divisor, K_Warp_Tile/divisor)
// and permuted with {0, 2, 3, 1, 4} so the N_Warp_Tile axis is moved between
// the two K sub-axes. Returns the permuted HostTensor.
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t,
               ck_tile::index_t N_Warp_Tile,
               ck_tile::index_t K_Warp_Tile)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];
    // divisor splits K_Warp_Tile: 2 when the warp tile is 32 wide, else 4 —
    // presumably tied to the hardware K packing (NOTE(review): confirm).
    int divisor = N_Warp_Tile == 32 ? 2 : 4;
    // Implicitly assumes n_ % N_Warp_Tile == 0 and k_ % K_Warp_Tile == 0.
    ck_tile::HostTensor<T> t_view(
        {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
// Preshuffle the B operand with an additional N permutation (permuteN
// variant): t is viewed as (N/N_Tile, N_Warp, N_Warp_Tile, NRepeat,
// K/K_Warp_Tile, divisor, K_Warp_Tile/divisor) and permuted with
// {0, 3, 1, 4, 5, 2, 6}. Returns the permuted HostTensor.
template <typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
                        ck_tile::index_t N_Warp_Tile,
                        ck_tile::index_t K_Warp_Tile,
                        ck_tile::index_t N_Tile,
                        ck_tile::index_t N_Warp)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];
    // divisor splits K_Warp_Tile: 2 when the warp tile is 32 wide, else 4 —
    // presumably tied to the hardware K packing (NOTE(review): confirm).
    int divisor = N_Warp_Tile == 32 ? 2 : 4;
    // Warp-tile repeats along N inside one block tile.
    int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
    // Implicitly assumes n_ % N_Tile == 0 and k_ % K_Warp_Tile == 0.
    ck_tile::HostTensor<T> t_view({n_ / N_Tile,
                                   N_Warp,
                                   N_Warp_Tile,
                                   NRepeat,
                                   k_ / K_Warp_Tile,
                                   divisor,
                                   K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -0,0 +1,300 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Load the shared GemmKernelBuilder class from the parent directory's
    gemm_instance_builder.py module."""
    this_dir = os.path.dirname(os.path.abspath(__file__))
    module_path = os.path.join(os.path.dirname(this_dir), "gemm_instance_builder.py")
    spec = importlib.util.spec_from_file_location("gemm_instance_builder", module_path)
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded.GemmKernelBuilder


GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
    """Kernel-instance builder for preshuffle GEMM.

    The constructor only forwards to GemmKernelBuilder; the work happens in
    _generate_all_individual, which fans kernel-header generation out over a
    process pool.
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
    ):
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )

    def _generate_all_individual(self, num_workers=None):
        """Generate individual kernel files for separate compilation with parallel processing"""
        if num_workers is None:
            num_workers = min(
                multiprocessing.cpu_count(), 8
            )  # Limit to avoid memory issues
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        # Prepare work items for parallel processing.
        # Each tuple carries everything a worker process needs to rebuild the
        # builder on its side (see _generate_single_kernel_individual).
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.kernel_name_prefix,
                        self.working_path,
                        self.gpu_target,
                        self.datatype,
                        self.layout,
                        self.config_json,
                    )
                )
        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {len(work_items)}")
        # Show first few work items for debugging
        if work_items:
            print(" First work item example:")
            tile_config, trait_combo = work_items[0][:2]
            print(f" Tile config: {tile_config}")
            print(f" Trait combo: {trait_combo[:3]}")  # Show first 3 traits
        # Process work items in parallel
        kernel_list = []
        completed = 0
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            # Submit all work items
            print(f" Submitting {len(work_items)} tasks to executor...")
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            print(" All tasks submitted, waiting for completion...")
            # Collect results with progress reporting (every 100 completions
            # plus a final line).
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 100 == 0 or completed == len(work_items):
                    print(
                        f" Progress: {completed}/{len(work_items)} kernels generated"
                    )
                try:
                    # Workers return None on failure; only keep real results.
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")
        # Sort kernel list for consistent ordering
        kernel_list.sort(key=lambda x: x[0])  # Sort by kernel name
        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker function to generate a single individual kernel file"""
    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json,
    ) = work_item
    # Each worker process builds its own builder instance from the plain
    # values packed into the work item.
    worker = GemmPreshuffleKernelBuilder(
        kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
    )
    try:
        kernel_name, instance_code = worker._generate_kernel_instance(
            tile_config, trait_combo
        )
        # Drop the "gemm_preshuffle_" prefix so generated file names stay short.
        stem = kernel_name
        if stem.startswith("gemm_preshuffle_"):
            stem = stem[len(kernel_name_prefix) + 1 :]
        target = working_path / f"gemm_preshuffle_single_{stem}.hpp"
        with open(target, "w") as out:
            out.write(instance_code)
        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """CLI entry point for the GEMM preshuffle kernel instance builder.

    Three mutually exclusive modes:
      --list_kernels        list kernel configurations without generating files
      --gen_all_individual  generate every kernel header (in parallel)
      --gen_single          generate one kernel; requires --kernel_name,
                            --tile_config and --trait_combo
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", required=True, help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()

    # argparse `choices` already constrains datatype/layout; the asserts give
    # clearer messages if this function is ever driven programmatically.
    assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
        f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
    )
    layout_parts = args.layout.lower()
    assert len(layout_parts) == 3, (
        f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
    )
    assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
        f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
    )
    assert layout_parts[2] == "r", (
        f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
    )

    # Create builder
    kernel_name_prefix = "gemm_preshuffle"
    builder = GemmPreshuffleKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.config_json,
    )
    if args.list_kernels:
        # Fast listing mode - just write kernel list without generating files
        builder._list_kernels()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config of the form "MxNxK_WMxWNxWK_WTMxWTNxWTK"
        tile_parts = args.tile_config.split("_")
        tile_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }
        # Parse trait combo; flags come in as "True"/"False" strings and are
        # converted to booleans.
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3] == "True",  # pad_m
            trait_parts[4] == "True",  # pad_n
            trait_parts[5] == "True",  # pad_k
            trait_parts[6] == "True",  # persistent
        )
        # Generate the kernel
        builder._generate_kernel_instance(
            tile_config,
            trait_combo,
        )
    elif args.gen_all_individual:
        # Generate all individual kernel files.
        # (Removed a stray dead `pass` statement that followed this call.)
        builder._generate_all_individual(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
if __name__ == "__main__":
main()

View File

@@ -111,30 +111,21 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
struct GemmConfig
{
ck_tile::index_t N_Warp_Tile;
ck_tile::index_t K_Warp_Tile;
ck_tile::index_t N_Tile;
ck_tile::index_t N_Warp;
};
for(const auto& callable : callables)
{
GemmConfig gemmConfig = {};
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if(config.permuteN)
{
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
}
else
{
return ck_tile::shuffle_b(b_k_n, gemmConfig);
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
}
}();

View File

@@ -0,0 +1,309 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)")
set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)")
set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF)
# Store the directory path for use in functions
set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# Function to create individual GEMM Universal targets
#
# Creates one EXCLUDE_FROM_ALL benchmark executable for a single kernel
# instance, plus a custom command that generates the instance header at
# build time via gemm_universal_instance_builder.py.
#
# Arguments:
#   datatype    - element datatype tag (e.g. fp16, fp8)
#   layout      - three-letter layout tag (e.g. rcr)
#   trait       - trait string: pipeline_epilogue_scheduler_padM_padN_padK_persistent
#   tile_config - tile string: tileMxtileNxtileK_warpMxwarpNxwarpK_wtMxwtNxwtK
#   config_json - path to the JSON configuration consumed by the builder
function(create_individual_gemm_universal_target datatype layout trait tile_config config_json)
# Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
return()
endif()
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
# First split by underscore to get three groups
string(REPLACE "_" ";" config_groups ${tile_config})
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
# Parse tile dimensions
# NOTE(review): the individual dimension variables parsed below are never
# used later in this function (the raw tile_config string is embedded in
# target/header names instead) - confirm whether this parsing can be removed.
string(REPLACE "x" ";" tile_parts ${tile_dims})
list(GET tile_parts 0 tile_m)
list(GET tile_parts 1 tile_n)
list(GET tile_parts 2 tile_k)
# Parse warp dimensions
string(REPLACE "x" ";" warp_parts ${warp_dims})
list(GET warp_parts 0 warp_m)
list(GET warp_parts 1 warp_n)
list(GET warp_parts 2 warp_k)
# Parse warp tile dimensions
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
list(GET warp_tile_parts 0 warp_tile_m)
list(GET warp_tile_parts 1 warp_tile_n)
list(GET warp_tile_parts 2 warp_tile_k)
set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Add custom command to generate the header file at build time
# (re-runs automatically when the builder script or the config JSON changes)
add_custom_command(
OUTPUT ${instance_header}
COMMAND ${Python3_EXECUTABLE} ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${config_json}
--gen_single
--kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
--gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}"
DEPENDS ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
# Create the executable
# EXCLUDE_FROM_ALL: benchmarks are built on demand, not as part of `all`.
add_executable(${target_name}
EXCLUDE_FROM_ALL
${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_benchmark_single.cpp
${instance_header}
)
# Set GPU architectures
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL})
# Set compile definitions
target_compile_definitions(${target_name} PRIVATE
GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}"
)
# Include directories
target_include_directories(${target_name} PRIVATE
${GEMM_UNIVERSAL_SOURCE_DIR}
${working_path}
)
# Compile options
# -include force-includes the generated instance header into the shared
# benchmark translation unit, so one .cpp source serves every instance.
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
-include ${instance_header}
)
# Add to collection targets
add_dependencies(benchmark_gemm_universal_all ${target_name})
add_dependencies(benchmark_gemm_universal_${datatype} ${target_name})
add_dependencies(benchmark_gemm_universal_${layout} ${target_name})
add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name})
# Add to trait-specific targets
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 pipeline)
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name})
add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name})
add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name})
endfunction()
# Function to build individual GEMM Universal targets
#
# For one (datatype, layout) pair: resolves the JSON config file, asks the
# Python instance builder to list every valid kernel configuration, then
# creates one benchmark target per listed kernel via
# create_individual_gemm_universal_target().
function(build_individual_gemm_universal_targets datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Choose config file
# Priority order:
# 1. Environment variable GEMM_UNIVERSAL_CONFIG_FILE
# 2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE
# 3. Default based on layout
# Check environment variable first
if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
message(VERBOSE " Using config from environment variable: ${config_filename}")
elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
# Use CMake variable if set
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}")
message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}")
else()
# Use default config for all layouts
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
message(VERBOSE " Using default config for layout ${layout}")
endif()
# Check if config file exists
if(NOT EXISTS ${json_blob})
message(FATAL_ERROR "Config file not found: ${json_blob}")
endif()
# Determine number of workers for parallel generation
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
else()
# Use processor count but limit to avoid memory issues
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
# NOTE(review): this math(EXPR) is an identity copy of num_cores; the actual
# clamping to 8 happens in the if() below.
math(EXPR num_workers "${num_cores}")
if(num_workers GREATER 8)
set(num_workers 8)
endif()
endif()
# Generate individual kernel files using parallel version
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
message(VERBOSE " Working path: ${working_path}")
message(VERBOSE " Config file: ${json_blob}")
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py")
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# First, just list the kernels (fast operation)
# Listing runs at configure time; actual header generation is deferred to
# build time by the per-target custom commands.
message(VERBOSE " Listing kernel configurations...")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
--list_kernels
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
# Read kernel count
if(EXISTS ${working_path}/gemm_universal_kernel_count.txt)
file(READ ${working_path}/gemm_universal_kernel_count.txt kernel_count)
string(STRIP "${kernel_count}" kernel_count)
message(VERBOSE " Found ${kernel_count} kernel configurations")
else()
message(FATAL_ERROR "Kernel count file not found")
endif()
# Read kernel list and create targets
if(EXISTS ${working_path}/gemm_universal_kernel_list.txt)
file(STRINGS ${working_path}/gemm_universal_kernel_list.txt kernel_lines)
foreach(line IN LISTS kernel_lines)
# Parse line: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Create individual target
create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
endforeach()
else()
message(FATAL_ERROR "Kernel list file not found")
endif()
endfunction()
# Main build logic - Only individual builds supported
#
# Top-level flow: filter SUPPORTED_GPU_TARGETS down to the architectures
# this op supports, create the collection (umbrella) targets, then create
# the per-kernel benchmark targets for every datatype/layout combination.
message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===")
message(VERBOSE "GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}")
message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${target})
message(VERBOSE " Adding GPU target: ${target}")
endif()
endforeach()
# Skip build if no matching targets found
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
# NOTE(review): job pools only take effect with the Ninja generator and when
# targets are assigned to a pool - confirm these pools are actually used.
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4 # Limit heavy compilations to prevent OOM
compile_normal=16 # Allow more parallel normal compilations
)
# Enable compiler cache if available and explicitly requested
# Disabled by default due to permission issues in CI environments
if(ENABLE_CCACHE_GEMM_UNIVERSAL)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
message(VERBOSE "Using ccache for faster compilation")
else()
message(WARNING "ccache requested but not found")
endif()
else()
message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)")
endif()
# Create master collection targets
# These umbrella targets must exist before create_individual_gemm_universal_target()
# runs, since it attaches each benchmark to them via add_dependencies().
add_custom_target(benchmark_gemm_universal_all)
# Create datatype collection targets
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
add_custom_target(benchmark_gemm_universal_${dt})
endforeach()
# Create layout collection targets
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
add_custom_target(benchmark_gemm_universal_${l})
endforeach()
# Create combined collection targets
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
add_custom_target(benchmark_gemm_universal_${dt}_${l})
endforeach()
endforeach()
# Create trait-based collection targets
# These are common trait components used across all GEMM Universal kernels
set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4")
set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle")
set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave")
foreach(pipeline IN LISTS GEMM_UNIVERSAL_PIPELINES)
add_custom_target(benchmark_gemm_universal_${pipeline}_pipeline)
endforeach()
foreach(epilogue IN LISTS GEMM_UNIVERSAL_EPILOGUES)
add_custom_target(benchmark_gemm_universal_${epilogue}_epilogue)
endforeach()
foreach(scheduler IN LISTS GEMM_UNIVERSAL_SCHEDULERS)
add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler)
endforeach()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
build_individual_gemm_universal_targets(${dt} ${l})
endforeach()
endforeach()
endif()

View File

@@ -2,12 +2,12 @@
"tile_config": {
"tile_m": {
"values": [
64
128
]
},
"tile_n": {
"values": [
192
128
]
},
"tile_k": {
@@ -17,12 +17,12 @@
},
"warp_m": {
"values": [
2
4
]
},
"warp_n": {
"values": [
2
1
]
},
"warp_k": {
@@ -32,24 +32,24 @@
},
"warp_tile_m": {
"values": [
32
16
]
},
"warp_tile_n": {
"values": [
32
16
]
},
"warp_tile_k": {
"values": [
8
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv4"
"mem"
]
},
"scheduler": {
@@ -59,7 +59,7 @@
},
"epilogue": {
"values": [
"cshuffle"
"default"
]
},
"pad_m": {
@@ -83,5 +83,5 @@
]
}
},
"k_block_per_cu": 1
"k_block_per_cu": 2
}

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -11,12 +11,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "gemm_profiler.hpp"
#include "gemm_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// TODO: This can be moved to a shared commons header.
// DataTypeTraits for all supported types
//
// Maps a C++ element type to the short string tag used in kernel names and
// benchmark output (e.g. ck_tile::half_t -> "fp16"). The primary template is
// declared but not defined, so instantiating it for an unsupported type is a
// compile-time error.
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
    static constexpr const char* name = "int32";
};
// NOTE(review): this tag uses the "_t" suffix unlike the others ("fp16",
// "int8", ...) - confirm downstream consumers expect "pk_int4_t" verbatim.
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
//
// Returns a ck_tile::bool_constant (a compile-time boolean type), not a plain
// bool, so the result can be used for tag dispatch as well as in constexpr
// contexts. The Layout argument is unused; only its type matters.
template <typename Layout>
constexpr auto is_row_major(Layout)
{
    return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher
//
// Field values mirror the trait-combo strings produced by the instance
// builder (pipeline_epilogue_scheduler_padM_padN_padK_persistent).
struct KernelTraits
{
    std::string pipeline;  // compv3, compv4, mem
    std::string scheduler; // intrawave, interwave
    std::string epilogue;  // cshuffle, default
    bool pad_m;            // pad the M dimension (presumably for sizes not divisible by the tile - verify)
    bool pad_n;            // pad the N dimension
    bool pad_k;            // pad the K dimension
    bool persistent;       // NOTE(review): assumed to select a persistent-kernel variant - confirm
    // Constructor with defaults
    KernelTraits()
        : pipeline("compv3"),
          scheduler("intrawave"),
          epilogue("cshuffle"),
          pad_m(false),
          pad_n(false),
          pad_k(false),
          persistent(false)
    {
    }
};

View File

@@ -0,0 +1,295 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
    """Dynamically load GemmKernelBuilder from gemm_instance_builder.py in the
    parent directory, so this script works regardless of sys.path."""
    here = os.path.dirname(os.path.abspath(__file__))
    builder_path = os.path.join(os.path.dirname(here), "gemm_instance_builder.py")
    # Import by file path rather than module name: the parent directory is not
    # necessarily a package on sys.path.
    spec = importlib.util.spec_from_file_location("gemm_instance_builder", builder_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.GemmKernelBuilder


GemmKernelBuilder = _import_gemm_kernel_builder()
class GemmUniversalKernelBuilder(GemmKernelBuilder):
    """Instance builder specialized for the GEMM Universal op.

    All configuration handling comes from GemmKernelBuilder; this subclass
    adds parallel generation of one header file per kernel instance.
    """

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
    ):
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )

    def _generate_all_individual(self, num_workers=None):
        """Generate individual kernel files for separate compilation with parallel processing"""
        if num_workers is None:
            # Cap the pool size to keep peak memory usage bounded.
            num_workers = min(multiprocessing.cpu_count(), 8)
        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        # One work item per (tile, trait) pair. Each item is fully
        # self-describing so worker processes can rebuild their own builder.
        work_items = [
            (
                tile_config,
                trait_combo,
                self.kernel_name_prefix,
                self.working_path,
                self.gpu_target,
                self.datatype,
                self.layout,
                self.config_json,
            )
            for tile_config in tile_configs
            for trait_combo in trait_combos
        ]
        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {len(work_items)}")
        if work_items:
            print(" First work item example:")
            tile_config, trait_combo = work_items[0][:2]
            print(f" Tile config: {tile_config}")
            print(f" Trait combo: {trait_combo[:3]}")  # Show first 3 traits
        kernel_list = []
        completed = 0
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            print(f" Submitting {len(work_items)} tasks to executor...")
            pending = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            print(" All tasks submitted, waiting for completion...")
            for future in concurrent.futures.as_completed(pending):
                completed += 1
                if completed % 100 == 0 or completed == len(work_items):
                    print(
                        f" Progress: {completed}/{len(work_items)} kernels generated"
                    )
                try:
                    outcome = future.result()
                except Exception as exc:
                    # A failure in one kernel must not abort the whole batch.
                    print(f"Kernel generation failed for {pending[future]}: {exc}")
                else:
                    if outcome:
                        kernel_list.append(outcome)
        # Deterministic ordering (by kernel name) for reproducible CMake output.
        kernel_list.sort(key=lambda entry: entry[0])
        self._generate_cmake_individual_targets(kernel_list)
        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
def _generate_single_kernel_individual(work_item):
    """Worker entry point: generate one kernel instance header file.

    Returns (kernel_name, trait_combo, tile_config) on success, or None on
    failure (the error is printed rather than raised so one bad configuration
    does not abort the whole batch).
    """
    (
        tile_config,
        trait_combo,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json,
    ) = work_item
    # Each worker process rebuilds its own builder from the work item fields.
    worker_builder = GemmUniversalKernelBuilder(
        kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
    )
    try:
        kernel_name, instance_code = worker_builder._generate_kernel_instance(
            tile_config, trait_combo
        )
        # Drop the "gemm_universal_" leader (prefix plus the underscore) so the
        # generated filename stays compact.
        short_name = (
            kernel_name[len(kernel_name_prefix) + 1 :]
            if kernel_name.startswith("gemm_universal_")
            else kernel_name
        )
        out_path = working_path / f"gemm_universal_single_{short_name}.hpp"
        with open(out_path, "w") as out_file:
            out_file.write(instance_code)
        return (kernel_name, trait_combo, tile_config)
    except Exception as exc:
        print(f"Error generating individual kernel: {exc}")
        return None
def main():
parser = argparse.ArgumentParser(
description="GEMM Universal kernel instance builder with parallel support"
)
parser.add_argument("--working_path", required=True, help="Working directory path")
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"--datatype",
required=True,
choices=["fp16", "fp8", "bf16", "bf8"],
help="Data type",
)
parser.add_argument(
"--layout",
required=True,
choices=["rcr", "rrr", "ccr", "crr"],
help="Matrix layout",
)
parser.add_argument("--config_json", help="Configuration JSON file")
parser.add_argument(
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
)
parser.add_argument(
"--gen_all_individual",
action="store_true",
help="Generate individual kernel files",
)
parser.add_argument(
"--gen_single", action="store_true", help="Generate a single kernel file"
)
parser.add_argument("--kernel_name", help="Kernel name for single generation")
parser.add_argument(
"--tile_config", help="Tile configuration string for single generation"
)
parser.add_argument(
"--trait_combo", help="Trait combination string for single generation"
)
parser.add_argument(
"--list_kernels",
action="store_true",
help="List kernel configurations without generating files",
)
args = parser.parse_args()
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
)
layout_parts = args.layout.lower()
assert len(layout_parts) == 3, (
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
)
assert layout_parts[2] == "r", (
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
)
kernel_name_prefix = "gemm_universal"
builder = GemmUniversalKernelBuilder(
kernel_name_prefix,
args.working_path,
args.gpu_target,
args.datatype,
args.layout,
args.config_json,
)
if args.list_kernels:
builder._list_kernels()
elif args.gen_single:
# Generate a single kernel file input validation
if not args.kernel_name or not args.tile_config or not args.trait_combo:
parser.error(
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
)
# Parse tile config
tile_parts = args.tile_config.split("_")
tile_dims = tile_parts[0].split("x")
warp_dims = tile_parts[1].split("x")
warp_tile_dims = tile_parts[2].split("x")
tile_config = {
"tile_m": int(tile_dims[0]),
"tile_n": int(tile_dims[1]),
"tile_k": int(tile_dims[2]),
"warp_m": int(warp_dims[0]),
"warp_n": int(warp_dims[1]),
"warp_k": int(warp_dims[2]),
"warp_tile_m": int(warp_tile_dims[0]),
"warp_tile_n": int(warp_tile_dims[1]),
"warp_tile_k": int(warp_tile_dims[2]),
}
# Parse trait combo
trait_parts = args.trait_combo.split("_")
trait_combo = (
trait_parts[0], # pipeline
trait_parts[1], # epilogue
trait_parts[2], # scheduler
trait_parts[3] == "True", # pad_m
trait_parts[4] == "True", # pad_n
trait_parts[5] == "True", # pad_k
trait_parts[6] == "True", # persistent
)
# Generate the kernel
builder._generate_kernel_instance(
tile_config,
trait_combo,
)
elif args.gen_all_individual:
# Generate all individual kernel files
builder._generate_all_individual(args.num_workers)
else:
parser.error(
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
)
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
@@ -660,7 +659,6 @@ def validate_whole_wg_cover_configuration(
)
if not wg_cover_core_valid:
print("I am here 3")
logging.debug(
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
)

View File

@@ -1,41 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher
struct KernelTraits
{
std::string pipeline; // compv3, compv4, mem
std::string scheduler; // intrawave, interwave
std::string epilogue; // cshuffle, default
bool pad_m;
bool pad_n;
bool pad_k;
bool persistent;
// Constructor with defaults
KernelTraits()
: pipeline("compv3"),
scheduler("intrawave"),
epilogue("cshuffle"),
pad_m(false),
pad_n(false),
pad_k(false),
persistent(false)
{
}
};

View File

@@ -1,891 +0,0 @@
#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import json
import argparse
import itertools
import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
import importlib.util
def _import_validation_utils():
"""Import validation utilities from commons directory."""
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
# Load the module dynamically
spec = importlib.util.spec_from_file_location(
"validation_utils",
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
)
validation_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(validation_utils)
return validation_utils
# Import validation functions
_validation_utils = _import_validation_utils()
is_tile_config_valid = _validation_utils.is_tile_config_valid
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
get_dtype_string = _validation_utils.get_dtype_string
get_abcd_layouts = _validation_utils.get_abcd_layouts
logging.basicConfig(level=logging.INFO)
class GemmMultiDKernelBuilder:
def __init__(
self,
working_path,
gpu_target,
datatype,
layout,
elementwise_function,
config_json=None,
):
self.working_path = Path(working_path)
self.gpu_target = gpu_target
self.datatype = datatype
self.layout = layout
self.elementwise_function = elementwise_function
self.config_json = config_json
# Create working directory if it doesn't exist
self.working_path.mkdir(parents=True, exist_ok=True)
# Load configuration
if config_json and os.path.exists(config_json):
with open(config_json, "r") as f:
self.config = json.load(f)
def write_kernel_list(self):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
tile_configs = self._get_tile_configs(fast_mode=False)
trait_combos = self._generate_trait_combinations()
kernel_list = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create kernel name with proper boolean capitalization
kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
# Create tile configuration string
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
kernel_name += f"_{tile_str}"
kernel_list.append(
{
"name": kernel_name,
"tile_config": tile_config,
"trait_combo": trait_combo,
}
)
# Write kernel count
with open(self.working_path / "gemm_multi_d_kernel_count.txt", "w") as f:
f.write(str(len(kernel_list)))
# Write kernel list
with open(self.working_path / "gemm_multi_d_kernel_list.txt", "w") as f:
for kernel in kernel_list:
# Format: kernel_name|tile_config|trait_combo
tile_config = kernel["tile_config"]
trait_combo = kernel["trait_combo"]
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = (
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
+ "_".join(str(x) for x in trait_combo[3:])
)
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
print(f"Listed {len(kernel_list)} kernel configurations")
def _get_tile_configs(self, fast_mode=False):
"""Get tile configurations for the current datatype and layout"""
tile_config = self.config["tile_config"]
# Generate values in the config if default range is given
if tile_config.get("tile_m").get("values") is None:
tile_config.get("tile_m")["values"] = self._generate_values(
tile_config.get("tile_m").get("min"),
tile_config.get("tile_m").get("max"),
tile_config.get("tile_m").get("step"),
)
if tile_config.get("tile_n").get("values") is None:
tile_config.get("tile_n")["values"] = self._generate_values(
tile_config.get("tile_n").get("min"),
tile_config.get("tile_n").get("max"),
tile_config.get("tile_n").get("step"),
)
if tile_config.get("tile_k").get("values") is None:
tile_config.get("tile_k")["values"] = self._generate_values(
tile_config.get("tile_k").get("min"),
tile_config.get("tile_k").get("max"),
tile_config.get("tile_k").get("step"),
)
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m").get("values")
tile_n_values = tile_config.get("tile_n").get("values")
tile_k_values = tile_config.get("tile_k").get("values")
warp_m_values = tile_config.get("warp_m").get("values")
warp_n_values = tile_config.get("warp_n").get("values")
warp_k_values = tile_config.get("warp_k").get("values")
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
def _generate_values(self, min_val, max_val, step):
"""Generate a list of values from min to max with the given step"""
values = []
val = min_val
while val <= max_val:
values.append(val)
val += step
return values
def _validate_tile_config(
self,
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline="compv4", # Default pipeline for validation
fast_mode=False, # Add fast mode option
):
"""Validate that tile configuration is reasonable"""
if fast_mode:
# Fast validation for listing - only basic sanity checks
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
return False
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
return False
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
return False
# Basic divisibility check
if tile_m % (warp_m * warp_tile_m) != 0:
return False
if tile_n % (warp_n * warp_tile_n) != 0:
return False
if tile_k % (warp_k * warp_tile_k) != 0:
return False
return True
else:
# Full validation for generation
# Determine data types for validation
a_datatype = self.datatype
b_datatype = self.datatype
c_datatype = self.datatype
layout = self.layout
# Special handling for certain data types
if self.datatype in ["fp8", "bf8"]:
c_datatype = "fp16"
# Use the comprehensive validation function
return is_tile_config_valid(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
a_datatype,
b_datatype,
c_datatype,
pipeline,
layout,
self.gpu_target,
)
def _generate_trait_combinations(self):
    """Build every supported (pipeline, epilogue, scheduler, pad_m, pad_n,
    pad_k, persistent) tuple declared in the ``trait_config`` section."""
    trait_config = self.config["trait_config"]

    def _values(key):
        # Each trait entry is a {"values": [...]} mapping.
        return trait_config.get(key).get("values")

    candidates = itertools.product(
        _values("pipeline"),
        _values("epilogue"),
        _values("scheduler"),
        _values("pad_m"),
        _values("pad_n"),
        _values("pad_k"),
        _values("persistent"),
    )
    # Filter out unsupported trait combinations.
    supported = []
    for candidate in candidates:
        pipeline, epilogue, scheduler = candidate[:3]
        if is_trait_combination_valid(pipeline, epilogue, scheduler):
            supported.append(candidate)
        else:
            logging.debug(
                f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
            )
    return supported
def _generate_kernel_instance(
    self, tile_config, trait_combo, k_block_per_cu, is_header=True
):
    """Generate a single kernel instance.

    Renders the C++ source for one multi-D GEMM ``SelectedKernel`` wrapper
    from a tile configuration and a trait combination.

    Args:
        tile_config: dict with tile_m/n/k, warp_m/n/k and warp_tile_m/n/k.
        trait_combo: (pipeline, epilogue, scheduler, pad_m, pad_n, pad_k,
            persistent) tuple.
        k_block_per_cu: emitted verbatim as the kBlockPerCu launch constant.
        is_header: when True, prefix the generated file with "#pragma once".

    Returns:
        (kernel_name, instance_code) — display name and C++ source text.
    """
    (
        pipeline,
        epilogue,
        scheduler,
        pad_m,
        pad_n,
        pad_k,
        persistent,
    ) = trait_combo
    # NOTE(review): `persistent` is only encoded into the kernel name; no
    # UsePersistentKernel constant is emitted below — confirm whether
    # persistent kernels are intentionally unsupported for multi-D.
    # Create kernel name with proper boolean capitalization
    kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
    # Create tile configuration string ("MxNxK_WMxWNxWK_WTMxWTNxWTK")
    tile_str = (
        f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
    )
    tile_str += (
        f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
    )
    tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
    kernel_name += f"_{tile_str}"
    # Map pipeline names to the correct pipeline implementation
    pipeline_impl_map = {
        "mem": "ck_tile::GemmPipelineAgBgCrMem",
        "compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
        "compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
    }
    # Map pipeline names to base pipeline for hot loop detection
    base_pipeline_map = {
        "mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
        "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
        "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
    }
    # Map scheduler names to the correct enum values
    scheduler_type_map = {
        "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
        "interwave": "ck_tile::GemmPipelineScheduler::Interwave",
        "default": "ck_tile::GemmPipelineScheduler::Default",
    }
    # Determine accumulator type based on datatype
    acc_type = "float"
    # Determine output type (fp8/bf8 inputs write fp16 output)
    c_type = self.datatype
    if self.datatype in ["fp8", "bf8"]:
        c_type = "fp16"
    # Determine layouts based on self.layout
    a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
    # Generate kernel instance code using the correct API
    pragma_line = "#pragma once\n" if is_header else ""
    instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {get_dtype_string(c_type)};
using D0DataType = {get_dtype_string(self.datatype)};
using D1DataType = {get_dtype_string(self.datatype)};
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
using D0Layout = {ds_layout[0]};
using D1Layout = {ds_layout[1]};
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
// Wrapper for simplified launch interface
struct SelectedKernel {{
// Tile configuration
static constexpr ck_tile::index_t BlockSize = 256;
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};
// Traits
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
static constexpr bool TransposeC = false;
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
// Tile partitioner
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
// Traits
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// Pipeline problem
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
Traits>;
// Base pipeline for hot loop detection
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>,
scheduler>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
// Epilogue
"""
    # Add epilogue configuration based on type
    if epilogue == "cshuffle":
        instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation>; // MemoryOperation_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
"""
    else:  # default epilogue
        instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
"""
    # Closing part of the template: kernel type, launch and k_batch dispatch
    # (atomic_add is required when split-K accumulates partial results).
    instance_code += f"""
// Kernel type
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Make kernel arguments
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Get grid and block sizes
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernelMultiD::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
}};
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""
    return kernel_name, instance_code
def run(self, num_workers=None):
    """Entry point: emit one generated source file per kernel configuration."""
    self.generate_individual(num_workers)
def generate_individual(self, num_workers=None):
    """Generate individual kernel files for separate compilation with parallel processing.

    Args:
        num_workers: size of the process pool; defaults to
            min(cpu_count, 8) to avoid memory pressure.

    Side effects:
        Writes one header per (tile config, trait combo) pair into
        self.working_path (via the module-level worker) and a CMake
        include file listing the generated targets.
    """
    if num_workers is None:
        num_workers = min(
            multiprocessing.cpu_count(), 8
        )  # Limit to avoid memory issues
    tile_configs = self._get_tile_configs()
    trait_combos = self._generate_trait_combinations()
    k_block_per_cu = self.config.get("k_block_per_cu")
    if k_block_per_cu is None:
        k_block_per_cu = 1
    # Prepare work items for parallel processing. Only plain, picklable
    # values go into each tuple so the items can cross process boundaries.
    work_items = []
    for tile_config in tile_configs:
        for trait_combo in trait_combos:
            work_items.append(
                (
                    tile_config,
                    trait_combo,
                    k_block_per_cu,
                    self.working_path,
                    self.gpu_target,
                    self.datatype,
                    self.layout,
                    self.elementwise_function,
                    self.config_json,
                )
            )
    print(
        f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
    )
    print(f" Tile configs: {len(tile_configs)}")
    print(f" Trait combinations: {len(trait_combos)}")
    print(f" Total kernels: {len(work_items)}")
    # Show first few work items for debugging
    if work_items:
        print(" First work item example:")
        tile_config, trait_combo = work_items[0][:2]
        print(f" Tile config: {tile_config}")
        print(f" Trait combo: {trait_combo[:3]}")  # Show first 3 traits
    # Process work items in parallel
    kernel_list = []
    completed = 0
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=num_workers
    ) as executor:
        # Submit all work items
        print(f" Submitting {len(work_items)} tasks to executor...")
        future_to_item = {
            executor.submit(_generate_single_kernel_individual, item): item
            for item in work_items
        }
        print(" All tasks submitted, waiting for completion...")
        # Collect results with progress reporting (every 100 completions)
        for future in concurrent.futures.as_completed(future_to_item):
            completed += 1
            if completed % 100 == 0 or completed == len(work_items):
                print(
                    f" Progress: {completed}/{len(work_items)} kernels generated"
                )
            try:
                result = future.result()
                if result:
                    kernel_list.append(result)
            except Exception as exc:
                item = future_to_item[future]
                print(f"Kernel generation failed for {item}: {exc}")
    # Sort kernel list for consistent ordering
    kernel_list.sort(key=lambda x: x[0])  # Sort by kernel name
    # Generate CMake include file for individual targets
    self._generate_cmake_individual_targets(kernel_list)
    print(
        f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
    )
def _generate_cmake_individual_targets(self, kernel_list):
    """Write a CMake include file declaring one build target per kernel."""
    lines = [
        "# Generated CMake file for individual GEMM Multi D targets",
        f"# Datatype: {self.datatype}, Layout: {self.layout}",
    ]
    for kernel_name, trait_combo, tile_config in kernel_list:
        pipeline, epilogue, scheduler = trait_combo[:3]
        # Encode the tile shape as "MxNxK_WMxWNxWK_WTMxWTNxWTK".
        shape_keys = (
            ("tile_m", "tile_n", "tile_k"),
            ("warp_m", "warp_n", "warp_k"),
            ("warp_tile_m", "warp_tile_n", "warp_tile_k"),
        )
        tile_str = "_".join(
            "x".join(str(tile_config[key]) for key in group)
            for group in shape_keys
        )
        trait_str = "_".join(
            [pipeline, epilogue, scheduler] + [str(x) for x in trait_combo[3:]]
        )
        lines.append(
            f'create_individual_gemm_multi_d_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")'
        )
    # Write CMake include file
    target_file = self.working_path / "gemm_multi_d_individual_targets.cmake"
    with open(target_file, "w") as f:
        f.write("\n".join(lines) + "\n")
def _generate_single_kernel_individual(work_item):
    """Worker: build one kernel header file; returns metadata or None on failure."""
    (
        tile_config,
        trait_combo,
        k_block_per_cu,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    ) = work_item
    # Each worker process builds its own builder instance; the parent's
    # builder object is not shared across the process boundary.
    builder = GemmMultiDKernelBuilder(
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    )
    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu
        )
        # Drop the "gemm_multi_d_" prefix to keep the file name short.
        prefix = "gemm_multi_d_"
        simplified_name = kernel_name
        if simplified_name.startswith(prefix):
            simplified_name = simplified_name[len(prefix):]
        header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)
        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """CLI entry point for the GEMM Multi D kernel instance builder.

    Parses arguments, validates datatype/layout/elementwise-function inputs,
    then dispatches to one of three modes: --list_kernels, --gen_single,
    or --gen_all_individual.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Multi D kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcrr", "rrrr", "ccrr", "crrr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--elementwise_function",
        required=True,
        help="Specify what element wise function for D, e.g. mul, add, passthrough",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()

    # Input validation. argparse `choices` already guards --datatype and
    # --layout; these asserts double-check and document the constraints.
    assert args.datatype in ["fp16"], (
        f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
    )
    layout_parts = args.layout.lower()
    assert len(layout_parts) == 4, (
        f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
    )
    assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
        f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
    )
    # Message typo fixed: the original read "andf".
    assert layout_parts[2] == "r" and layout_parts[3] == "r", (
        f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
    )

    # Elementwise function name validation: map the CLI spelling to the
    # ck_tile elementwise functor name.
    elementwise_function = args.elementwise_function.lower()
    function_name_map = {
        "mul": "MultiDMultiply",
        "add": "MultiDAdd",
        "passthrough": "PassThrough",
    }
    if elementwise_function not in function_name_map:
        raise ValueError(
            f"Invalid elementwise function: {elementwise_function}. "
            f"Valid options are: {', '.join(function_name_map)}"
        )
    args.elementwise_function = function_name_map[elementwise_function]

    # Create builder
    builder = GemmMultiDKernelBuilder(
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.elementwise_function,
        args.config_json,
    )
    if args.list_kernels:
        builder.write_kernel_list()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config string "MxNxK_WMxWNxWK_WTMxWTNxWTK"
        tile_parts = args.tile_config.split("_")
        tile_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }
        # Parse trait combo "pipeline_epilogue_scheduler_PadM_PadN_PadK_Persistent"
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3] == "True",  # pad_m
            trait_parts[4] == "True",  # pad_n
            trait_parts[5] == "True",  # pad_k
            trait_parts[6] == "True",  # persistent
        )
        k_block_per_cu = builder.config.get("k_block_per_cu")
        if k_block_per_cu is None:
            k_block_per_cu = 1
        # Generate the kernel
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu
        )
        # Write the file, dropping the "gemm_multi_d_" prefix from the name
        simplified_name = kernel_name
        if simplified_name.startswith("gemm_multi_d_"):
            simplified_name = simplified_name[13:]
        header_file = (
            builder.working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)
        print(f"Generated {header_file}")
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
# Allow use both as a CLI script and as an importable module.
if __name__ == "__main__":
    main()

View File

@@ -1,83 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Aggregated trait strings/flags used by the kernel dispatcher.
// Defaults are supplied via member initializers; the defaulted constructor
// reproduces the original constructor's defaulting behavior.
struct KernelTraits
{
    std::string pipeline  = "preshufflev2"; // preshufflev2
    std::string scheduler = "default";      // intrawave, interwave, default
    std::string epilogue  = "default";      // cshuffle, default
    bool pad_m      = false;
    bool pad_n      = false;
    bool pad_k      = false;
    bool persistent = false;

    KernelTraits() = default;
};
// Recover dispatcher traits from a generated kernel name.
//
// Generated names follow the builder's format:
//   gemm_preshuffle_<dtype>_<layout>_<pipeline>_<epilogue>_<scheduler>_
//   <PadM>_<PadN>_<PadK>_<Persistent>_<MxNxK_...>
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
{
    KernelTraits traits;
    // Extract pipeline
    if(kernel_name.find("preshufflev2") != std::string::npos)
    {
        traits.pipeline = "preshufflev2";
    }
    // Extract scheduler
    if(kernel_name.find("interwave") != std::string::npos)
    {
        traits.scheduler = "interwave";
    }
    else if(kernel_name.find("intrawave") != std::string::npos)
    {
        traits.scheduler = "intrawave";
    }
    else
    {
        traits.scheduler = "default";
    }
    // Extract epilogue. A name contains exactly one epilogue token, either
    // "cshuffle" or "default"; testing for "cshuffle" directly fixes the
    // previous heuristic (find("default") && !find("default_")), which
    // misclassified a "default" epilogue as "cshuffle" because the token is
    // always followed by '_' in generated names.
    traits.epilogue =
        kernel_name.find("cshuffle") != std::string::npos ? "cshuffle" : "default";
    // Extract the four capitalized booleans (pad_m, pad_n, pad_k,
    // persistent) that follow the scheduler token. If they cannot be
    // located (unexpected name format, or the "default" anchor matched the
    // epilogue token), fall back to the previous behavior of leaving all
    // flags false.
    const std::string anchor = "_" + traits.scheduler + "_";
    const std::string::size_type pos = kernel_name.find(anchor);
    if(pos != std::string::npos)
    {
        bool flags[4] = {false, false, false, false};
        int parsed    = 0;
        std::string::size_type cursor = pos + anchor.size();
        while(parsed < 4 && cursor < kernel_name.size())
        {
            const std::string::size_type next = kernel_name.find('_', cursor);
            const std::string token =
                kernel_name.substr(cursor,
                                   next == std::string::npos ? std::string::npos
                                                             : next - cursor);
            if(token == "True")
            {
                flags[parsed++] = true;
            }
            else if(token == "False")
            {
                flags[parsed++] = false;
            }
            else
            {
                break; // not a boolean token: wrong anchor or format
            }
            if(next == std::string::npos)
            {
                break;
            }
            cursor = next + 1;
        }
        if(parsed == 4)
        {
            traits.pad_m      = flags[0];
            traits.pad_n      = flags[1];
            traits.pad_k      = flags[2];
            traits.persistent = flags[3];
        }
    }
    return traits;
}

View File

@@ -1,894 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import argparse
import os
import json
import itertools
import logging
import multiprocessing
import concurrent.futures
from pathlib import Path
import importlib.util
def _import_validation_utils():
    """Dynamically load ``gemm_validation_utils`` from the sibling
    ``commons`` directory and return the loaded module."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    utils_path = os.path.join(
        os.path.dirname(script_dir), "commons", "gemm_validation_utils.py"
    )
    # Load the module dynamically (it is not on sys.path).
    spec = importlib.util.spec_from_file_location("validation_utils", utils_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
# Import validation functions
# Load the shared validation helpers once at import time and re-export the
# callables this module uses by name.
_validation_utils = _import_validation_utils()
is_tile_config_valid = _validation_utils.is_tile_config_valid
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
get_dtype_string = _validation_utils.get_dtype_string
get_abc_layouts = _validation_utils.get_abc_layouts

# Configure root logging for this generator script.
logging.basicConfig(level=logging.INFO)
class GemmPreshuffleKernelBuilder:
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
    """Set up a preshuffle-GEMM kernel builder.

    Args:
        working_path: directory for generated files (created if missing).
        gpu_target: GPU architecture string (e.g. "gfx942").
        datatype: element datatype name (e.g. "fp16").
        layout: 4-character layout string (e.g. "rrrr").
        config_json: optional path to a JSON file with tile/trait config.
    """
    self.working_path = Path(working_path)
    self.gpu_target = gpu_target
    self.datatype = datatype
    self.layout = layout
    self.config_json = config_json
    # Create working directory if it doesn't exist
    self.working_path.mkdir(parents=True, exist_ok=True)
    # Load configuration. Default to an empty mapping so later accesses
    # fail with a clear KeyError instead of AttributeError when no config
    # file was provided (previously self.config was left unset).
    self.config = {}
    if config_json and os.path.exists(config_json):
        with open(config_json, "r") as f:
            self.config = json.load(f)
def write_kernel_list(self):
    """Write kernel list to file for CMake to read (with comprehensive validation)"""

    def tile_string(tc):
        # "MxNxK_WMxWNxWK_WTMxWTNxWTK" encoding of a tile configuration.
        return (
            f"{tc['tile_m']}x{tc['tile_n']}x{tc['tile_k']}_"
            f"{tc['warp_m']}x{tc['warp_n']}x{tc['warp_k']}_"
            f"{tc['warp_tile_m']}x{tc['warp_tile_n']}x{tc['warp_tile_k']}"
        )

    # Get configurations using comprehensive validation.
    tile_configs = self._get_tile_configs(fast_mode=False)
    trait_combos = self._generate_trait_combinations()

    kernel_list = []
    for tile_config in tile_configs:
        for trait_combo in trait_combos:
            pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
                trait_combo
            )
            # Booleans are capitalized ("True"/"False") in the kernel name.
            flags = "_".join(
                str(flag).capitalize() for flag in (pad_m, pad_n, pad_k, persistent)
            )
            kernel_name = (
                f"gemm_preshuffle_{self.datatype}_{self.layout}_"
                f"{pipeline}_{epilogue}_{scheduler}_{flags}_{tile_string(tile_config)}"
            )
            kernel_list.append(
                {
                    "name": kernel_name,
                    "tile_config": tile_config,
                    "trait_combo": trait_combo,
                }
            )

    # Write kernel count
    with open(self.working_path / "gemm_preshuffle_kernel_count.txt", "w") as f:
        f.write(str(len(kernel_list)))

    # Write kernel list
    with open(self.working_path / "gemm_preshuffle_kernel_list.txt", "w") as f:
        for kernel in kernel_list:
            # Format: kernel_name|tile_config|trait_combo
            tc = kernel["tile_config"]
            combo = kernel["trait_combo"]
            trait_str = "_".join(
                [str(combo[0]), str(combo[1]), str(combo[2])]
                + [str(x) for x in combo[3:]]
            )
            f.write(f"{kernel['name']}|{tile_string(tc)}|{trait_str}\n")
    print(f"Listed {len(kernel_list)} kernel configurations")
def _get_tile_configs(self, fast_mode=False):
    """Enumerate all valid tile configurations for this datatype/layout.

    Missing "values" lists for tile_m/n/k are expanded in place from their
    min/max/step ranges (the loaded config is mutated so later calls reuse
    the expanded lists). Every combination of tile, warp and warp-tile
    sizes is produced — in the same order as the original nine nested
    loops — and filtered through _validate_tile_config.

    Args:
        fast_mode: forwarded to _validate_tile_config; when True only
            cheap sanity checks are applied.

    Returns:
        List of dicts keyed tile_m..warp_tile_k.
    """
    tile_config = self.config["tile_config"]
    # Generate values in the config if a default min/max/step range is given.
    for dim in ("tile_m", "tile_n", "tile_k"):
        entry = tile_config.get(dim)
        if entry.get("values") is None:
            entry["values"] = self._generate_values(
                entry.get("min"),
                entry.get("max"),
                entry.get("step"),
            )
    # Get all possible values for each parameter, in nesting order.
    keys = (
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
    )
    value_lists = [tile_config.get(key).get("values") for key in keys]
    # itertools.product iterates in exactly the order of the previous
    # nine-level nested loop (rightmost key varies fastest).
    configs = []
    for combo in itertools.product(*value_lists):
        if self._validate_tile_config(*combo, fast_mode=fast_mode):
            configs.append(dict(zip(keys, combo)))
    return configs
def _generate_values(self, min_val, max_val, step):
    """Inclusive arithmetic progression from ``min_val`` to ``max_val``
    in increments of ``step``."""
    out = []
    cursor = min_val
    while cursor <= max_val:
        out.append(cursor)
        cursor += step
    return out
def _generate_trait_combinations(self):
    """Generate all supported trait combinations.

    Supports both configuration formats:
      * old: {"traits": {...}, "padding": {...}, "persistent": [...]}
      * new: {"trait_config": {"pipeline": {"values": [...]}, ...}}
    Falls back to a single minimal default combination when neither
    section is present. Unsupported (pipeline, epilogue, scheduler)
    triples are filtered via is_trait_combination_valid.

    Returns:
        List of (pipeline, epilogue, scheduler, pad_m, pad_n, pad_k,
        persistent) tuples.
    """
    if "traits" in self.config:
        # Old format
        traits = self.config["traits"]
        padding = self.config["padding"]
        candidates = itertools.product(
            traits["pipelines"],
            traits["epilogues"],
            traits["schedulers"],
            padding["pad_m"],
            padding["pad_n"],
            padding["pad_k"],
            self.config["persistent"],
        )
    elif "trait_config" in self.config:
        # New format: every field has a defensive default.
        trait_config = self.config["trait_config"]
        candidates = itertools.product(
            trait_config.get("pipeline", {}).get("values", ["preshufflev2"]),
            trait_config.get("epilogue", {}).get("values", ["default"]),
            trait_config.get("scheduler", {}).get("values", ["default"]),
            trait_config.get("pad_m", {}).get("values", [False]),
            trait_config.get("pad_n", {}).get("values", [False]),
            trait_config.get("pad_k", {}).get("values", [False]),
            trait_config.get("persistent", {}).get("values", [False]),
        )
    else:
        # Fallback to minimal default (no filtering needed).
        return [("preshufflev2", "default", "default", False, False, False, False)]

    # Shared filtering step (previously duplicated in both format branches).
    combinations = []
    for combo in candidates:
        pipeline, epilogue, scheduler = combo[:3]
        if is_trait_combination_valid(pipeline, epilogue, scheduler):
            combinations.append(combo)
        else:
            logging.debug(
                f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
            )
    return combinations
def _validate_tile_config(
    self,
    tile_m,
    tile_n,
    tile_k,
    warp_m,
    warp_n,
    warp_k,
    warp_tile_m,
    warp_tile_n,
    warp_tile_k,
    pipeline="preshufflev2",  # Default pipeline for validation
    fast_mode=False,  # Fast mode: cheap sanity checks only
):
    """Validate that tile configuration is reasonable.

    In ``fast_mode`` only positivity and divisibility checks run. The
    full path additionally enforces the preshuffle-specific permute_n
    constraint and delegates to is_tile_config_valid.

    Returns:
        True if the configuration is acceptable.
    """
    if fast_mode:
        # Fast validation for listing - only basic sanity checks
        if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
            return False
        if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
            return False
        if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
            return False
        # Basic divisibility check: each block-tile dimension must be an
        # exact multiple of (warps per block) * (warp tile size).
        if tile_m % (warp_m * warp_tile_m) != 0:
            return False
        if tile_n % (warp_n * warp_tile_n) != 0:
            return False
        if tile_k % (warp_k * warp_tile_k) != 0:
            return False
        return True

    # Validate preshuffle specific constraints: permute_n requires the
    # per-warp N repeat count (tile_n / warp_tile_n / warp_n) to be an even
    # integer. Exact integer arithmetic replaces the previous float
    # expression ((tile_n / warp_tile_n / warp_n) % 2 == 0), which is
    # equivalent (true iff the exact quotient is an even integer) but
    # avoids float-precision pitfalls.
    if self.config.get("permute_n"):
        n_per_block = warp_tile_n * warp_n
        if tile_n % n_per_block != 0 or (tile_n // n_per_block) % 2 != 0:
            return False

    # Full validation for generation.
    # Special handling for certain data types: fp8/bf8 write fp16 output.
    c_datatype = "fp16" if self.datatype in ["fp8", "bf8"] else self.datatype
    # Use the comprehensive validation function
    return is_tile_config_valid(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        self.datatype,
        self.datatype,
        c_datatype,
        pipeline,
        self.layout,
        self.gpu_target,
    )
def _generate_kernel_instance(
self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True
):
"""Generate a single kernel instance"""
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create kernel name with proper boolean capitalization
kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
# Create tile configuration string
tile_str = (
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
)
tile_str += (
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
)
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
kernel_name += f"_{tile_str}"
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
}
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
}
# Map scheduler names to the correct enum values
scheduler_type_map = {
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
"default": "ck_tile::GemmPipelineScheduler::Default",
}
# Determine accumulator type based on datatype
acc_type = "float"
# Determine output type
c_type = self.datatype
if self.datatype in ["fp8", "bf8"]:
c_type = "fp16"
# Determine layouts based on self.layout
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
# Generate kernel instance code using the correct API
pragma_line = "#pragma once\n" if is_header else ""
instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {get_dtype_string(c_type)};
using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
// Wrapper for simplified launch interface
struct SelectedKernel {{
// Tile configuration
static constexpr ck_tile::index_t BlockSize = 256;
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};
// Traits
static constexpr bool kPadM = {"true" if pad_m == "true" else "false"};
static constexpr bool kPadN = {"true" if pad_n == "true" else "false"};
static constexpr bool kPadK = {"true" if pad_k == "true" else "false"};
static constexpr bool TransposeC = false;
static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"};
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "preshufflev2" else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Preshuffle = true;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PermuteN = {"true" if permute_n else "false"};
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
false, false>;
// Tile partitioner
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
// Traits
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;
// Pipeline problem
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
Traits>;
// Base pipeline for hot loop detection
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
// Epilogue
"""
# Add epilogue configuration based on type
if epilogue == "cshuffle":
instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation, // MemoryOperation_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
"""
else: # default epilogue
instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
"""
instance_code += f"""
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Make kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
}};
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""
return kernel_name, instance_code
def run(self, num_workers=None):
"""Run the builder to generate individual kernel files"""
# Generate individual kernel files
self.generate_individual(num_workers)
def generate_individual(self, num_workers=None):
"""Generate individual kernel files for separate compilation with parallel processing"""
if num_workers is None:
num_workers = min(
multiprocessing.cpu_count(), 8
) # Limit to avoid memory issues
tile_configs = self._get_tile_configs()
trait_combos = self._generate_trait_combinations()
k_block_per_cu = self.config.get("k_block_per_cu")
permute_n = self.config.get("permute_n")
# Prepare work items for parallel processing
work_items = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
work_items.append(
(
tile_config,
trait_combo,
k_block_per_cu,
permute_n,
self.working_path,
self.datatype,
self.layout,
)
)
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
print(f" Tile configs: {len(tile_configs)}")
print(f" Trait combinations: {len(trait_combos)}")
print(f" Total kernels: {len(work_items)}")
# Show first few work items for debugging
if work_items:
print(" First work item example:")
tile_config, trait_combo = work_items[0][:2]
print(f" Tile config: {tile_config}")
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
# Process work items in parallel
kernel_list = []
completed = 0
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
# Submit all work items
print(f" Submitting {len(work_items)} tasks to executor...")
future_to_item = {
executor.submit(_generate_single_kernel_individual, item): item
for item in work_items
}
print(" All tasks submitted, waiting for completion...")
# Collect results with progress reporting
for future in concurrent.futures.as_completed(future_to_item):
completed += 1
if completed % 100 == 0 or completed == len(work_items):
print(
f" Progress: {completed}/{len(work_items)} kernels generated"
)
try:
result = future.result()
if result:
kernel_list.append(result)
except Exception as exc:
item = future_to_item[future]
print(f"Kernel generation failed for {item}: {exc}")
# Sort kernel list for consistent ordering
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
# Generate CMake include file for individual targets
self._generate_cmake_individual_targets(kernel_list)
print(
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
)
def _generate_cmake_individual_targets(self, kernel_list):
"""Generate CMake include file that creates individual targets"""
cmake_code = f"""# Generated CMake file for individual GEMM Preshuffle targets
# Datatype: {self.datatype}, Layout: {self.layout}
"""
for kernel_name, trait_combo, tile_config in kernel_list:
pipeline, epilogue, scheduler = trait_combo[:3]
# Format tile config for CMake function
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
str(x) for x in trait_combo[3:]
)
cmake_code += f'create_individual_gemm_preshuffle_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
# Write CMake include file
with open(
self.working_path / "gemm_preshuffle_individual_targets.cmake", "w"
) as f:
f.write(cmake_code)
def _generate_single_kernel_individual(work_item):
    """Process-pool worker: generate one individual kernel header file.

    Args:
        work_item: tuple of (tile_config, trait_combo, k_block_per_cu,
            permute_n, working_path, datatype, layout), as packed by
            generate_individual().

    Returns:
        (kernel_name, trait_combo, tile_config) on success, or None if
        generation failed (the error is printed and the caller skips
        this kernel).
    """
    (
        tile_config,
        trait_combo,
        k_block_per_cu,
        permute_n,
        working_path,
        datatype,
        layout,
    ) = work_item
    # Workers run in separate processes, so each rebuilds its own builder.
    # NOTE(review): this call passes 3 args while main() constructs
    # GemmPreshuffleKernelBuilder with 5 (working_path, gpu_target, datatype,
    # layout, config_json) — confirm the builder signature accepts both.
    builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout)
    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu, permute_n
        )
        # File names drop the shared "gemm_preshuffle_" prefix; strip it by
        # prefix length rather than a magic slice index.
        prefix = "gemm_preshuffle_"
        simplified_name = kernel_name
        if simplified_name.startswith(prefix):
            simplified_name = simplified_name[len(prefix):]
        # NOTE(review): main()'s --gen_single path writes
        # "gemm_preshuffle_single_*.hpp" while this writes "gemm_single_*.hpp"
        # — confirm which name the CMake targets expect.
        header_file = working_path / f"gemm_single_{simplified_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)
        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        # Best effort: report and return None so one bad kernel does not
        # abort the whole parallel generation run.
        print(f"Error generating individual kernel: {e}")
        return None
def main():
    """CLI entry point for the GEMM preshuffle kernel-instance generator.

    Supports three mutually exclusive modes:
      --list_kernels        write the kernel list without generating files
      --gen_single          generate one kernel (needs --kernel_name,
                            --tile_config and --trait_combo)
      --gen_all_individual  generate every kernel file in parallel
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", required=True, help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()
    # Defensive validation. The `choices` lists above already guarantee these
    # hold today; use parser.error (not `assert`, which is stripped under
    # `python -O`) so the checks survive if the choices are ever widened.
    if args.datatype not in ["fp16", "bf16", "fp8", "bf8"]:
        parser.error(
            f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
        )
    layout_parts = args.layout.lower()
    if len(layout_parts) != 3:
        parser.error(
            f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
        )
    if layout_parts[0] not in ["r"] or layout_parts[1] not in ["c"]:
        parser.error(
            f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
        )
    if layout_parts[2] != "r":
        parser.error(
            f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
        )
    # Create builder
    builder = GemmPreshuffleKernelBuilder(
        args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
    )
    if args.list_kernels:
        # Fast listing mode - just write kernel list without generating files
        builder.write_kernel_list()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        # Parse tile config of the form "MxNxK_MxNxK_MxNxK"
        # (block tile, warps per block, warp tile).
        tile_parts = args.tile_config.split("_")
        tile_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")
        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }
        # Parse trait combo "pipeline_epilogue_scheduler_padM_padN_padK_persistent"
        trait_parts = args.trait_combo.split("_")
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3] == "True",  # pad_m
            trait_parts[4] == "True",  # pad_n
            trait_parts[5] == "True",  # pad_k
            trait_parts[6] == "True",  # persistent
        )
        k_block_per_cu = builder.config.get("k_block_per_cu")
        permute_n = builder.config.get("permute_n")
        # Generate the kernel
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu, permute_n
        )
        # Strip the shared prefix by length instead of a magic slice index.
        prefix = "gemm_preshuffle_"
        simplified_name = kernel_name
        if simplified_name.startswith(prefix):
            simplified_name = simplified_name[len(prefix):]
        header_file = (
            builder.working_path / f"gemm_preshuffle_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)
        print(f"Generated {header_file}")
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
# Script entry point: run the kernel-instance generator CLI.
if __name__ == "__main__":
    main()