From 06fad18aaf3d0fe158c88cb8ee7bedc75185a72f Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Sat, 30 Aug 2025 09:54:18 -0400 Subject: [PATCH] Restructure the Tile Engine to have faster build time and clear config report (#2747) * Making edits to identify individual compilation issues. * Minor fix for blob txt files not being created. * Fixing compilation issues. * Fixing ordering bug. * Adding python profiling functionality. * Setting individual build as default. * Setting gpu target filtering for tile engine to gfx90a, gfx942 and gfx950. * update the default running parameters and settings * Fixing bug with benchmarking, shifting file generation to build instead of config. * Updating fixes. * Fixing json output and parsing. * Disable ccache for tile engine gemm ops because we dont need it. * Removing duplicate type definition. * Improving json printing. * Add the flexibility of different layout and more warp tile support * Fix extra flag in name of individual kernels. * Fixing bug with booleans. * Solve the first patch of the post merge conflict * Compilation fixes, and cosmetic improvements. * Yet again compilation fixes after latest changes from develop. * Fixing python benchmarking script. 
--------- Co-authored-by: Vidyasagar Ananthan Co-authored-by: Vidyasagar Ananthan [ROCm/composable_kernel commit: 705804d9bf87e1e2fca23c0af231efcdebf76efb] --- script/cmake-ck-dev.sh | 13 +- tile_engine/ops/gemm/CMakeLists.txt | 418 +++-- tile_engine/ops/gemm/README.md | 495 ++++- tile_engine/ops/gemm/benchmark_gemm.cpp | 68 - tile_engine/ops/gemm/benchmark_gemm.hpp | 19 +- .../ops/gemm/benchmark_gemm_single.cpp | 160 ++ tile_engine/ops/gemm/codegen_utils.py | 8 + tile_engine/ops/gemm/configs/benchmark.json | 12 +- .../ops/gemm/configs/default_config.json | 200 +- tile_engine/ops/gemm/gemm_benchmark.py | 721 ++++++++ tile_engine/ops/gemm/gemm_common.hpp | 197 ++ tile_engine/ops/gemm/gemm_host_api.hpp | 223 --- tile_engine/ops/gemm/gemm_instance_builder.py | 1612 +++++++++-------- tile_engine/ops/gemm/gemm_profiler.hpp | 37 +- tile_engine/ops/gemm/test_benchmark.sh | 102 ++ tile_engine/ops/gemm/test_validation.py | 143 ++ tile_engine/ops/gemm/validation_utils.py | 342 ++++ 17 files changed, 3361 insertions(+), 1409 deletions(-) delete mode 100644 tile_engine/ops/gemm/benchmark_gemm.cpp create mode 100644 tile_engine/ops/gemm/benchmark_gemm_single.cpp create mode 100755 tile_engine/ops/gemm/gemm_benchmark.py create mode 100644 tile_engine/ops/gemm/gemm_common.hpp delete mode 100644 tile_engine/ops/gemm/gemm_host_api.hpp mode change 100755 => 100644 tile_engine/ops/gemm/gemm_instance_builder.py create mode 100755 tile_engine/ops/gemm/test_benchmark.sh create mode 100644 tile_engine/ops/gemm/test_validation.py create mode 100644 tile_engine/ops/gemm/validation_utils.py diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index b93555901e..217ec998bd 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -25,13 +25,20 @@ if [ $# -ge 1 ]; then GPU_TARGETS=$1 shift 1 echo "GPU targets provided: $GPU_TARGETS" + REST_ARGS=("$@") ;; *) - echo "No GPU targets provided, using default targets: $GPU_TARGETS" + echo "No GPU targets provided, using default 
targets: gfx908;gfx90a;gfx942" + GPU_TARGETS="gfx908;gfx90a;gfx942" + shift 1 + REST_ARGS=("$@") ;; esac else - echo "No GPU targets provided, using default targets: $GPU_TARGETS" + echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" + GPU_TARGETS="gfx908;gfx90a;gfx942" + shift 1 + REST_ARGS=("$@") fi cmake \ @@ -43,5 +50,5 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ -$@ \ +"${REST_ARGS[@]}" \ \ ${MY_PROJECT_SOURCE} diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 42c114b499..d52351af2d 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,169 +1,295 @@ - set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") +set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) -function(build_gemm_for_datatype datatype layout) - # Filter GPU targets to only gfx90a, gfx942, and gfx950 - set(GEMM_GPU_TARGETS "") - set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") - - foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_GPU_TARGETS ${target}) - endif() - endforeach() - - # Skip compilation if no matching targets found - if(NOT GEMM_GPU_TARGETS) - message(WARNING "Skipping Tile Engine GEMM compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +# Store the directory path for use in functions +set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM targets +function(create_individual_gemm_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable + if(NOT 
GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") return() endif() - message(STATUS "Building GEMM for GPU targets: ${GEMM_GPU_TARGETS}") + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Comment this if-else block when using user_provided_config - if(layout STREQUAL "rcr") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - else() - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") - endif() - - # uncomment this if you want to use user_provided_config.json - # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") - # Generate kernel list + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at 
build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + ${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_all ${target_name}) + add_dependencies(benchmark_gemm_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_${layout} ${target_name}) + add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_${pipeline} ${target_name}) + add_dependencies(benchmark_gemm_${epilogue} ${target_name}) + add_dependencies(benchmark_gemm_${scheduler} ${target_name}) +endfunction() + +# Function to build individual GEMM targets +function(build_individual_gemm_targets datatype layout) + set(working_path 
"${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_CONFIG_FILE + # 2. CMake variable GEMM_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(STATUS " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") + message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(STATUS " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(STATUS " Working path: ${working_path}") + message(STATUS " Config file: ${json_blob}") + message(STATUS " Python executable: ${Python3_EXECUTABLE}") + message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + # First, just list 
the kernels (fast operation) + message(STATUS " Listing kernel configurations...") execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${working_path} --datatype ${datatype} --layout ${layout} --config_json ${json_blob} - --list_blobs + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error ) + if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}") + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") endif() - - file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs) - file(STRINGS "${working_path}/gemm_instance_blobs_range.txt" codegen_blobs_range) - # Generate the blobs - add_custom_command( - OUTPUT ${codegen_blobs} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path "${working_path}" - --datatype ${datatype} - --layout ${layout} - --config_json "${json_blob}" - --gen_blobs - COMMENT "Generating GEMM instance sources for ${datatype} ${layout}" - ) - add_custom_target(gemm_gen_${datatype}_${layout} DEPENDS ${codegen_blobs}) - - set(intermediate_libs) - list(LENGTH codegen_blobs codegen_blobs_len) - - foreach(blob IN LISTS codegen_blobs_range) - string(STRIP "${blob}" stripped_blob) - separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}") - # Each line is: - list(GET spilit_blob 0 name) - list(GET spilit_blob 1 first) - list(GET spilit_blob 2 last) - math(EXPR total_files "${last} - ${first}") - if(total_files EQUAL 0) - continue() # nothing for this trait - endif() - - # Object libraries (chunked) per trait - set(sub_intermediate_libs) - set(chunk_size 3) - math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}") - math(EXPR num_chunks_minus_1 "${num_chunks} - 
1") - - foreach(i RANGE 0 ${num_chunks_minus_1}) - math(EXPR start "${first} + ${i} * ${chunk_size} ") - math(EXPR end "${start} + ${chunk_size} - 1") - - set(chunk_files) - foreach(j RANGE ${start} ${end}) - if(j LESS ${last} AND j LESS ${codegen_blobs_len}) - list(GET codegen_blobs ${j} f) - list(APPEND chunk_files "${f}") - endif() - endforeach() - - #list(LENGTH chunk_files chunk_files_len) - #if(chunk_files_len AND chunk_files_len GREATER 1) - if(chunk_files) - set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}_${layout}") - add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) - set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) - endif() - + # Read kernel count + if(EXISTS ${working_path}/gemm_kernel_count.txt) + file(READ ${working_path}/gemm_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_kernel_list.txt) + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") endforeach() - - # ------------------ Bundle the object libs into one static lib --------- - #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) - #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) - if(sub_intermediate_libs) - set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}_${layout}") - # Collect the $ expressions - 
- set(obj_exprs) - foreach(objlib IN LISTS sub_intermediate_libs) - list(APPEND obj_exprs $) - endforeach() - - add_library(${intermediate_lib_name} STATIC ${obj_exprs}) - add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}_${layout}) - set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - #foreach(objlib IN LISTS sub_intermediate_libs) - # target_sources(${intermediate_lib_name} PRIVATE $) - #endforeach() - list(APPEND intermediate_libs ${intermediate_lib_name}) - endif() - - endforeach() - - # Interface library for instances - add_library(gemm_template_instances_${datatype}_${layout} INTERFACE) - add_dependencies(gemm_template_instances_${datatype}_${layout} gemm_gen_${datatype}_${layout}) - target_link_libraries(gemm_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs}) - target_include_directories(gemm_template_instances_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - set_target_properties(gemm_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX) - - # Host API interface library - add_library(gemm_host_api_${datatype}_${layout} INTERFACE) - target_link_libraries(gemm_host_api_${datatype}_${layout} INTERFACE gemm_template_instances_${datatype}_${layout}) - target_include_directories(gemm_host_api_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - - - # Executable per datatype - set(exec_name "benchmark_gemm_${datatype}_${layout}") - add_executable(${exec_name} benchmark_gemm.cpp) - set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}_${layout}) - target_compile_options(${exec_name} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress - ) + else() + message(FATAL_ERROR "Kernel list file not found") + endif() endfunction() -# Process each datatype in isolation -foreach(dt IN LISTS 
GEMM_DATATYPE) - foreach(l IN LISTS GEMM_LAYOUT) - build_gemm_for_datatype(${dt} ${l}) - endforeach() +# Main build logic - Only individual builds supported +message(STATUS "=== Starting Tile Engine GEMM Configuration ===") +message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}") +message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, and gfx950 +set(GEMM_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target}) + message(STATUS " Adding GPU target: ${target}") + endif() endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_all) + + # Create datatype collection targets + 
foreach(dt IN LISTS GEMM_DATATYPE) + add_custom_target(benchmark_gemm_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_LAYOUT) + add_custom_target(benchmark_gemm_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_DATATYPE) + foreach(l IN LISTS GEMM_LAYOUT) + add_custom_target(benchmark_gemm_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM kernels + set(GEMM_PIPELINES "mem;compv3;compv4") + set(GEMM_EPILOGUES "default;cshuffle") + set(GEMM_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_PIPELINES) + add_custom_target(benchmark_gemm_${pipeline}) + endforeach() + + foreach(epilogue IN LISTS GEMM_EPILOGUES) + add_custom_target(benchmark_gemm_${epilogue}) + endforeach() + + foreach(scheduler IN LISTS GEMM_SCHEDULERS) + add_custom_target(benchmark_gemm_${scheduler}) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_DATATYPE) + foreach(l IN LISTS GEMM_LAYOUT) + build_individual_gemm_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 79152a1a0d..01ffbb6da7 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -1,113 +1,442 @@ -# GEMM Matrix Multiplication +# CK Tile Engine GEMM Operations -CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues. +## Overview -# Kernel Configurations +The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. 
The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities. -Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time. -For reference please see `./configs/user_provided_config.json`. +## Table of Contents +1. [Build System Architecture](#build-system-architecture) +2. [Build Instructions](#build-instructions) +3. [Running Benchmarks](#running-benchmarks) +4. [Configuration System](#configuration-system) +5. [Scripts and Tools](#scripts-and-tools) +6. [Command Line Options](#command-line-options) +7. [Understanding Kernel Names](#understanding-kernel-names) +8. [Troubleshooting](#troubleshooting) +9. [Performance Tips](#performance-tips) -The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` +## Build System Architecture -If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. +### Individual Kernel Compilation (New Approach) + +The new tile engine benchmark system compiles each kernel configuration into a separate executable. 
This provides: +- Better build parallelism +- Faster incremental builds +- More targeted testing +- Easier debugging of specific configurations + +Each benchmark executable follows the naming pattern: +``` +benchmark_gemm____ +``` + +### Monolithic Build (Legacy Approach) + +The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments. ## Build Instructions -``` bash -# in the root of composable kernel create build directory + +### Prerequisites +- ROCm installation +- CMake 3.16 or higher +- C++17 compatible compiler + +### Basic Build + +```bash +# In the root of composable kernel, create build directory mkdir build && cd build -# build composable kernel -# replace [Arch] with the appropriate architecture or leave blank and -# replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16]) -# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) -../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" -# generate different executable for each passed datatype + +# Configure with specific datatypes and layouts +# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942) +# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64) +# Replace [Layout1;Layout2;...] with layouts (rcr, rrr, crr, ccr) +../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" + +# Build specific benchmarks make benchmark_gemm_[Datatype1]_[Layout1] -j -make benchmark_gemm_[Datatype1]_[Layout2] -j -make benchmark_gemm_[Datatype2]_[Layout1] -j -make benchmark_gemm_[Datatype2]_[Layout2] -j -``` -`benchmark_gemm_[Datatype]_[Layout]` will be located in the `./bin/` directory. 
- -`benchmark_gemm_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified. - -``` bash -rm -rf tile_engine/ && make benchmark_gemm_[Datatypes]_[Layout] -j # rebuild ``` -## For eaxmple build for gfx942 for fp8 and fp16 datatypes with rcr layout -``` bash +### Configuration Options + +The build system supports several configuration options: + +#### Using Custom Config Files +```bash +# Method 1: CMake variable (config file must be in configs/ directory) +cmake -DGEMM_CONFIG_FILE=my_custom_config.json ... + +# Method 2: Environment variable (takes precedence over CMake variable) +export GEMM_CONFIG_FILE=my_custom_config.json +cmake ... +``` + +#### Config File Priority Order +1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority) +2. **CMake variable** `GEMM_CONFIG_FILE` +3. **Default config** (default_config.json for all layouts) + +**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory. + +### Example Build Commands + +```bash +# Build for gfx942 with fp8 and fp16 datatypes, rcr layout mkdir build && cd build -../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr" +../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr" make benchmark_gemm_fp8_rcr -j make benchmark_gemm_fp16_rcr -j ``` -## benchmark_gemm inputs +### Building Individual Kernels + +```bash +# Build a specific kernel configuration +make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32 + +# Build all fp16 benchmarks in parallel +make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}') ``` - -m The value for m dimension. Default is 3840. - -n The value for n dimension. Default is 4096. - -k The value for k dimension. Default is 2048. - -stride_a The stride value for tensor A. Default is 0. - -stride_b The stride value for tensor B. Default is 0. 
- -stride_c The stride value for tensor C Default is 0. - -split_k The split value for k dimension. Default is 1. - -v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU. - -log Wether output kernel instance information or not. Possible values are true or false. Default is false. - -warmup The number of iterations before benchmark the kernel. Default is 50. - -repeat The number of iterations to benchmark the kernel. Default is 100. - -timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true. - -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. - -flush_cache To flush cache in between different runs.Possible values are true or false. Default is false. - -rotating_count count to flush cache. Default is 5. - -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. - -csv_filename The filename of benchmark result. Default is gemm_kernel. - -structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false. - -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. - -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. - -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. - -pad_n Whether pad or not in n direction. Possible values are true or false. Default is false. - -pad_k Whether pad or not in k direction. Possible values are true or false. Default is false. 
-Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json +### Rebuilding After Configuration Changes + +If you modify the configuration file, you must rebuild: +```bash +rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j ``` -Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. -## Example +## Running Benchmarks -The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. +### Individual Kernel Execution + +```bash +cd /path/to/build/directory +./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \ + -m=512 -n=512 -k=512 -verify=1 +``` + +### Monolithic Executable (Legacy) + +```bash +# Run specific pipeline/scheduler/epilogue combination +./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default +``` + +### Automated Testing + +Use the provided test script to run multiple benchmarks: +```bash +cd /path/to/composable_kernel/tile_engine/ops/gemm +./test_benchmark.sh [build_directory] +``` + +## Configuration System + +### Configuration Files + +The system uses JSON configuration files to specify kernel parameters: + +- `configs/default_config.json` - Default configurations for various datatypes +- `configs/user_provided_config.json` - User-customizable configurations + +### Configuration Structure ```json -{ - /// other parameters /// - - "tile_m": { - "values": [256] +{ + "tile_config": { + "tile_m": {"values": [256, 128]}, + "tile_n": {"values": [256, 128]}, + "tile_k": {"values": [64, 32]}, + "warp_m": {"values": [2, 4]}, + "warp_n": {"values": [2, 1]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32, 16]}, + "warp_tile_n": {"values": [32, 16]}, + 
"warp_tile_k": {"values": [16, 32]} }, - "tile_n": { - "values": [256] - }, - "tile_k": { - "values": [64, 32] - }, - - /// other parameters /// - - "pipeline": { - "values": ["compv3", "compv4", "mem"] - }, - "scheduler": { - "values": ["intrawave", "interwave"] - }, - "epilogue": { - "values": ["default", "cshuffle"] + "trait_config": { + "pipeline": {"values": ["compv3", "compv4", "mem"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} } } ``` -At runtime, a specific subset of the generated kernels can be selected using command-line arguments. -``` bash -./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default -``` -The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. +## Scripts and Tools +### Python Scripts + +#### gemm_instance_builder.py +**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files. + +**Key Features**: +- Generates individual kernel header files for separate compilation +- Supports multiple data types (fp16, fp8, bf16, fp32, fp64) +- Validates tile configurations for correctness +- Creates CMake integration files + +**Usage**: +```bash +python gemm_instance_builder.py \ + --working_path ./generated \ + --datatype fp16 \ + --layout rcr \ + --config_json configs/user_provided_config.json \ + --gen_individual +``` + +#### gemm_instance_builder_parallel.py +**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations. 
+ +**Features**: +- Multi-threaded kernel generation +- Improved performance for large configuration spaces + +#### validation_utils.py +**Purpose**: Provides comprehensive validation functions for kernel configurations. + +**Key Functions**: +- `is_tile_config_valid()` - Validates tile dimensions and alignments +- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported +- `validate_warp_tile_combination()` - GPU-specific warp tile validation +- `validate_lds_capacity()` - Ensures configurations fit in LDS memory + +**Validation Checks**: +- Dimension alignment (tile dimensions must be divisible by warp dimensions) +- LDS capacity constraints +- GPU-specific warp tile support +- Unsupported trait combinations + +#### test_validation.py +**Purpose**: Test suite for the validation logic to ensure correctness. + +**Usage**: +```bash +python test_validation.py +``` + +**Tests**: +- Warp tile combination validation +- Trait combination validation +- Full tile configuration validation + +#### gemm_benchmark.py +**Purpose**: Python script for running and analyzing GEMM benchmarks. + +**Features**: +- Automated benchmark execution +- Performance data collection +- Result analysis and reporting + +#### json_config.py +**Purpose**: Configuration file parsing and management. + +**Features**: +- JSON configuration loading +- Default configuration handling +- Configuration validation + +#### codegen_utils.py +**Purpose**: Utility functions for code generation. + +**Features**: +- Template processing +- Code formatting utilities +- File generation helpers + +### Shell Scripts + +#### test_benchmark.sh +**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables. 
+ +**Features**: +- Automatic build directory detection +- Batch execution of multiple benchmarks +- CSV result collection +- Colored output for easy reading +- Example command generation + +**Usage**: +```bash +# Auto-detect build directory +./test_benchmark.sh + +# Specify build directory +./test_benchmark.sh /path/to/build/directory +``` + +**What it does**: +1. Finds all benchmark executables in the build directory +2. Runs each with multiple problem sizes (512, 1024, 2048) +3. Performs GPU verification +4. Saves results to timestamped CSV file +5. Provides summary statistics + +## Command Line Options + +All benchmark executables support the following options: + +### Matrix Dimensions +- `-m=` - M dimension (default: 3840) +- `-n=` - N dimension (default: 4096) +- `-k=` - K dimension (default: 2048) + +### Strides +- `-stride_a=` - Stride for matrix A (default: 0, auto-calculated) +- `-stride_b=` - Stride for matrix B (default: 0, auto-calculated) +- `-stride_c=` - Stride for matrix C (default: 0, auto-calculated) + +### Verification +- `-verify=<0|1|2>` - Verification mode + - 0: No verification (default) + - 1: CPU verification + - 2: GPU verification + +### Performance Testing +- `-warmup=` - Warmup iterations (default: 50) +- `-repeat=` - Benchmark iterations (default: 100) +- `-timer=` - Use GPU timer (default: true) +- `-flush_cache=` - Flush cache between runs (default: true) +- `-rotating_count=` - Cache rotation count (default: 1000) + +### Initialization +- `-init=<0|1|2>` - Tensor initialization method + - 0: Random values [-1, 1] (default) + - 1: Linear sequence (i % 17) + - 2: Constant value (1.0) + +### Output Options +- `-log=` - Enable verbose logging (default: false) +- `-metric=<0|1|2>` - Performance metric + - 0: Latency in ms (default) + - 1: TFLOPS + - 2: Bandwidth in GB/s +- `-json_output=` - JSON format output (default: false) +- `-csv_filename=` - Save results to CSV +- `-csv_format=` - CSV format (default: comprehensive) + +### 
Advanced Options +- `-split_k=` - Split-K factor (default: 1) +- `-structured_sparsity=` - Enable structured sparsity (default: false) +- `-pipeline=` - Pipeline type (default: compv3) +- `-scheduler=` - Scheduler type (default: intrawave) +- `-epilogue=` - Epilogue type (default: cshuffle) +- `-pad_m=` - Pad M dimension (default: false) +- `-pad_n=` - Pad N dimension (default: false) +- `-pad_k=` - Pad K dimension (default: false) +- `-persistent=` - Use persistent kernel (default: false) + +## Understanding Kernel Names + +The kernel naming convention encodes the configuration: + +``` +benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 + ^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^ ^^^^^^^ ^^^^^^^^^ + | | | | | | | | | + | | | | | Padding & flags | | Warp tile + | | | | Scheduler | Thread tile + | | | Epilogue Block tile + | | Pipeline + | Layout (Row-Column-Row) + Data type +``` + +### Components: +- **Data type**: fp16, fp32, bf16, fp8, bf8, int8 +- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr +- **Pipeline**: mem, compv3, compv4 +- **Epilogue**: default, cshuffle +- **Scheduler**: intrawave, interwave +- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags) +- **Tile sizes**: BlockTile x ThreadTile x WarpTile + +## Troubleshooting + +### Common Issues + +1. **Kernel not found** + - Ensure the specific benchmark executable is built + - Check the build directory bin/ folder + +2. **Verification failures** + - Try GPU verification (-verify=2) which may be more accurate + - Check data type compatibility + - Verify stride calculations + +3. **Build failures** + - Check GPU architecture compatibility + - Ensure ROCm is properly installed + - Verify configuration file syntax + +4. 
**Performance variations** + - Increase warmup iterations + - Disable CPU frequency scaling + - Use GPU timer for accurate measurements + +### Debug Options + +Enable verbose logging: +```bash +./bin/benchmark_gemm_... -log=true -verify=1 +``` + +Test validation logic: +```bash +python test_validation.py +``` + +## Performance Tips + +1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions +2. **Warmup**: Use at least 50-100 warmup iterations +3. **GPU Timer**: Always use `-timer=true` for accurate measurements +4. **Cache Management**: Enable cache flushing for consistent results +5. **Thread Affinity**: Set CPU affinity to reduce variation + +## Integration Examples + +### Python Integration + +```python +import subprocess +import json + +# Run benchmark with JSON output +result = subprocess.run([ + './bin/benchmark_gemm_fp16_rcr_...', + '-m=1024', '-n=1024', '-k=1024', + '-json_output=true' +], capture_output=True, text=True) + +# Parse results +data = json.loads(result.stdout) +print(f"Performance: {data['tflops']} TFLOPS") +``` + +### Batch Testing Script + +```bash +#!/bin/bash +SIZES="512 1024 2048 4096" +for size in $SIZES; do + echo "Testing ${size}x${size}x${size}" + ./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \ + -verify=2 -csv_filename=results.csv +done +``` + +## Contributing + +When adding new features or configurations: +1. Update validation logic in `validation_utils.py` +2. Add tests to `test_validation.py` +3. Update configuration examples +4. Document new command-line options + +For more information about the Composable Kernel project, visit the main repository documentation. diff --git a/tile_engine/ops/gemm/benchmark_gemm.cpp b/tile_engine/ops/gemm/benchmark_gemm.cpp deleted file mode 100644 index db2b648437..0000000000 --- a/tile_engine/ops/gemm/benchmark_gemm.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include - -#include "gemm_profiler.hpp" -#include "benchmark_gemm.hpp" - -void benchmark_gemm(const ck_tile::ArgParser& arg_parser) -{ - GemmProblem gemm_problem{arg_parser.get_int("split_k"), - arg_parser.get_int("m"), - arg_parser.get_int("n"), - arg_parser.get_int("k"), - arg_parser.get_int("stride_a"), - arg_parser.get_int("stride_b"), - arg_parser.get_int("stride_c"), - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - ALayout::name, - BLayout::name, - CLayout::name, - arg_parser.get_bool("structured_sparsity")}; - - Setting setting{arg_parser.get_int("warmup"), - arg_parser.get_int("repeat"), - arg_parser.get_bool("timer"), - arg_parser.get_int("verify"), - arg_parser.get_int("init"), - arg_parser.get_bool("log"), - arg_parser.get_str("csv_filename"), - arg_parser.get_bool("flush_cache"), - arg_parser.get_int("rotating_count")}; - - auto& profiler = GemmProfiler::instance(setting); - - try - { - auto kernel_func = get_kernel_func_by_trait(arg_parser); - profiler.benchmark(gemm_problem, kernel_func); - profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); - } - catch(const std::exception& e) - { - std::cerr << "Benchmark failed: " << e.what() << std::endl; - } -} - -int main(int argc, char* argv[]) -{ - try - { - auto [result, parser] = create_args(argc, argv); - if(!result) - return EXIT_FAILURE; - benchmark_gemm(parser); - return 0; - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << "\n"; - return EXIT_FAILURE; - } -} diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index ce8a6e8234..0e2619785e 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -7,8 +7,14 @@ #include #include #include +#include -#include "gemm_host_api.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_common.hpp" + +// Data types and Layouts are 
defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts enum class Metric { @@ -55,8 +61,9 @@ struct GemmProblem << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" - << " \"layout_c\":\"" << problem.layout_c_ << "\"\n" - << " \"structured_sparsity\":\"" << problem.structured_sparsity_ << "\"\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\",\n" + << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") + << "\n" << "}"; return os; } @@ -105,9 +112,8 @@ struct KernelInstance friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" - << " \"name\": \"" << "{\n" - << obj.name_ << "\n}" << "\",\n" - << " \"problem\": \"" << obj.problem_ << "\",\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; return os; @@ -125,6 +131,7 @@ struct Setting std::string csv_filename_; bool flush_cache_; int rotating_count_; + bool json_output_; }; inline std::string get_rocm_version() diff --git a/tile_engine/ops/gemm/benchmark_gemm_single.cpp b/tile_engine/ops/gemm/benchmark_gemm_single.cpp new file mode 100644 index 0000000000..58532ffbe8 --- /dev/null +++ b/tile_engine/ops/gemm/benchmark_gemm_single.cpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_profiler.hpp" +#include "gemm_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_c", "0", "The stride value for tensor C. Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "0", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU. Default is 0, no validation.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert( + "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") + .insert( + "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "true", + "To flush cache, possible values are true or false. 
" + "Default is false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false. Default is " + "false") + .insert("json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false. " + "Default is " + "false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + 
arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + // Get the profiler instance + auto& profiler = GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + + benchmark_gemm_single(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 392125aa0b..6a87193043 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -170,6 +170,14 @@ warp_tile_supported_combinations = { [16, 16, 128], [32, 32, 64], ], + "fp8_bf8_fp16": [ + [16, 16, 128], + [32, 32, 64], + ], + "bf8_fp8_fp16": [ + [16, 16, 128], + [32, 32, 64], + ], }, } diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json index def3ca4453..b15b587147 100644 --- a/tile_engine/ops/gemm/configs/benchmark.json +++ b/tile_engine/ops/gemm/configs/benchmark.json @@ -5,20 +5,17 @@ "tile_m": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, "tile_n": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, "tile_k": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, 
"warp_m": { "values": [ @@ -79,7 +76,8 @@ }, "epilogue": { "values": [ - "cshuffle" + "cshuffle", + "default" ] }, "pad_m": { diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 5bd51b809a..b245c3167f 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -1,105 +1,105 @@ { - "problem": { - }, - "tile_config": { - "tile_m": { - "values": [ - 256 - ] + "problem": { }, - "tile_n": { - "values": [ - 128, - 256 - ] + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_k": { + "max": 256, + "min": 64, + "step": 64 + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } }, - "tile_k": { - "values": [ - 32 - ] - }, - "warp_m": { - "values": [ - 1, - 2, - 4 - ] - }, - "warp_n": { - "values": [ - 1, - 2, - 4 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 4, - 16, - 32 - ] - }, - "warp_tile_n": { - "values": [ - 16, - 32, - 64 - ] - }, - "warp_tile_k": { - "values": [ - 8, - 16, - 32, - 64, - 128 - ] + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, + true + ] + } } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4", - "mem" - ] - }, - "scheduler": { - "values": [ - 
"intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "cshuffle", - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false - ] - } - } } diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_benchmark.py new file mode 100755 index 0000000000..3b0f0e619d --- /dev/null +++ b/tile_engine/ops/gemm/gemm_benchmark.py @@ -0,0 +1,721 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import sys +import json +import subprocess +import argparse +import csv +import time +from pathlib import Path +from typing import List, Dict, Tuple, Optional + + +class GemmBenchmark: + def __init__(self, build_dir: str, verbose: bool = False): + self.build_dir = Path(build_dir) + self.verbose = verbose + self.results = [] + + def discover_kernels(self) -> List[Path]: + """Find all benchmark_gemm_* executables in the build directory""" + bin_dir = self.build_dir / "bin" + if not bin_dir.exists(): + print(f"Error: Binary directory {bin_dir} does not exist") + return [] + + kernels = list(bin_dir.glob("benchmark_gemm_*")) + if self.verbose: + print(f"Found {len(kernels)} kernel executables") + for k in kernels: + print(f" - {k.name}") + return kernels + + def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]: + """Extract comprehensive kernel information from filename""" + name = kernel_path.stem + + # Initialize with basic info + info = { + "executable": str(kernel_path), + "name": name, + "data_type": "unknown", + "layout": "unknown", + "pipeline": "unknown", + "scheduler": "unknown", + "epilogue": "unknown", + } + + # Parse the kernel name pattern: + # benchmark_gemm_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16 + parts = name.split("_") + + if len(parts) >= 3: + # Extract data type 
(3rd part after benchmark_gemm_) + info["data_type"] = parts[2] if len(parts) > 2 else "unknown" + + # Extract layout (4th part) + info["layout"] = parts[3] if len(parts) > 3 else "unknown" + + # Extract pipeline (5th part) + info["pipeline"] = parts[4] if len(parts) > 4 else "unknown" + + # Extract epilogue (6th part) + info["epilogue"] = parts[5] if len(parts) > 5 else "unknown" + + # Extract scheduler (7th part) + info["scheduler"] = parts[6] if len(parts) > 6 else "unknown" + + # Extract detailed configuration from the end of the name + config_info = self.parse_detailed_config(name) + info.update(config_info) + + # Generate config ID + info["config_id"] = self.generate_config_id(info) + + return info + + def parse_detailed_config(self, kernel_name: str) -> Dict: + """Parse detailed configuration from kernel name""" + config = { + "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0}, + "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0}, + "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0}, + "optimization_flags": { + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + } + + # Split by underscore and look for patterns + parts = kernel_name.split("_") + + # Look for boolean flags (sequence of True/False values) + bool_sequence = [] + for i, part in enumerate(parts): + if part in ["True", "False"]: + bool_sequence.append(part == "True") + # Continue collecting consecutive boolean values + j = i + 1 + while j < len(parts) and parts[j] in ["True", "False"]: + bool_sequence.append(parts[j] == "True") + j += 1 + break + + # Assign boolean flags if we found them + # Order: pad_m, pad_n, pad_k, persistent (4 flags total) + if len(bool_sequence) >= 4: + config["optimization_flags"]["pad_m"] = bool_sequence[0] + config["optimization_flags"]["pad_n"] = bool_sequence[1] + config["optimization_flags"]["pad_k"] = bool_sequence[2] + config["optimization_flags"]["persistent"] = bool_sequence[3] + + # Look for tile size patterns 
(e.g., 256x256x32_2x2x1_4x64x16) + # The pattern is: tile_sizes_warp_config_warp_tile + dimension_groups = [] + for part in parts: + if "x" in part and len(part.split("x")) == 3: + try: + dims = [int(x) for x in part.split("x")] + if all(d > 0 for d in dims): + dimension_groups.append(dims) + except ValueError: + continue + + # Assign dimensions based on order and magnitude + if len(dimension_groups) >= 3: + # Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Largest dimensions = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smallest dimensions = warp config + config["warp_config"]["warp_m"] = sorted_groups[2][0] + config["warp_config"]["warp_n"] = sorted_groups[2][1] + config["warp_config"]["warp_k"] = sorted_groups[2][2] + + # Middle dimensions = warp tile + config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0] + config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1] + config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 2: + # If only 2 groups, assign based on magnitude + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Larger = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smaller = warp config + config["warp_config"]["warp_m"] = sorted_groups[1][0] + config["warp_config"]["warp_n"] = sorted_groups[1][1] + config["warp_config"]["warp_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 1: + # Only one group - assume it's tile sizes + config["tile_sizes"]["tile_m"] = dimension_groups[0][0] + config["tile_sizes"]["tile_n"] = dimension_groups[0][1] + config["tile_sizes"]["tile_k"] = 
dimension_groups[0][2] + + return config + + def generate_config_id(self, info: Dict) -> str: + """Generate a compact config ID from kernel info""" + # Create a compact identifier + parts = [ + info.get("data_type", "unk"), + info.get("layout", "unk"), + info.get("pipeline", "unk"), + info.get("scheduler", "unk"), + ] + + # Add tile configuration if available + tile_sizes = info.get("tile_sizes", {}) + if tile_sizes.get("tile_m", 0) > 0: + tile_str = ( + f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}" + ) + parts.append(tile_str) + + # Add warp config if available + warp_config = info.get("warp_config", {}) + if warp_config.get("warp_m", 0) > 0: + warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}" + parts.append(warp_str) + + # Add warp tile if available + warp_tile = info.get("warp_tile", {}) + if warp_tile.get("warp_tile_m", 0) > 0: + warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}" + parts.append(warp_tile_str) + + return "_".join(parts) + + def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]: + """Run a single kernel with given parameters and save output to individual JSON file""" + # Create results directory + results_dir = self.build_dir / "results" + results_dir.mkdir(exist_ok=True) + + # Generate unique JSON filename for this kernel + json_file = results_dir / f"{kernel_path.stem}.json" + + cmd = [str(kernel_path)] + + # Add parameters + for key, value in params.items(): + cmd.append(f"-{key}={value}") + + # Add JSON output flag for clean JSON output + cmd.append("-json_output=true") + + if self.verbose: + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if result.returncode != 0: + print(f"Error running {kernel_path.name}: {result.stderr}") + return None + + # Save raw output to individual JSON file + output = result.stdout.strip() + if 
output: + with open(json_file, "w") as f: + f.write(output) + + # Parse the JSON file + return self.parse_json_file(json_file) + else: + print(f"No output from {kernel_path.name}") + return None + + except subprocess.TimeoutExpired: + print(f"Timeout running {kernel_path.name}") + return None + except Exception as e: + print(f"Error running {kernel_path.name}: {e}") + return None + + def parse_json_file(self, json_file: Path) -> Optional[Dict]: + """Parse JSON data from individual kernel output file""" + try: + with open(json_file, "r") as f: + content = f.read().strip() + + # Parse the JSON directly since executables produce clean JSON + data = json.loads(content) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON from {json_file}: {e}") + return None + except Exception as e: + if self.verbose: + print(f"Error reading JSON file {json_file}: {e}") + return None + + def parse_benchmark_output(self, output: str) -> Optional[Dict]: + """Parse the benchmark output format - extract JSON directly""" + try: + # Find JSON block between asterisk markers + lines = output.split("\n") + json_start = -1 + json_end = -1 + + for i, line in enumerate(lines): + if line.strip().startswith("{"): + json_start = i + elif line.strip().endswith("}") and json_start != -1: + json_end = i + break + + if json_start != -1 and json_end != -1: + json_text = "\n".join(lines[json_start : json_end + 1]) + data = json.loads(json_text) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = 
data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + return None + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON: {e}") + print(f"Output was: {output[:200]}...") + return None + except Exception as e: + if self.verbose: + print(f"Error parsing output: {e}") + return None + + def benchmark_problem_size( + self, + kernels: List[Path], + m: int, + n: int, + k: int, + split_k: int = 1, + verify: int = 0, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> List[Dict]: + """Benchmark all kernels for a specific problem size""" + results = [] + + params = { + "m": m, + "n": n, + "k": k, + "split_k": split_k, + "verify": verify, + "warmup": warmup, + "repeat": repeat, + "flush_cache": str(flush_cache).lower(), + "rotating_count": rotating_count, + } + + print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}") + + for kernel_path in kernels: + kernel_info = self.extract_kernel_info(kernel_path) + result = self.run_kernel(kernel_path, params) + + if result: + # Create new structured result format + structured_result = { + "name": kernel_info["name"], # Add name field for compatibility + "config_id": kernel_info["config_id"], + "problem": result.get("problem", {}), + "perf_result": result.get("perf_result", {}), + "config": { + "data_type": kernel_info["data_type"], + "layout": kernel_info["layout"], + "pipeline": kernel_info["pipeline"], + "scheduler": kernel_info["scheduler"], + "epilogue": kernel_info["epilogue"], + "tile_sizes": kernel_info.get("tile_sizes", {}), + "warp_config": kernel_info.get("warp_config", {}), + "warp_tile": kernel_info.get("warp_tile", {}), + "optimization_flags": kernel_info.get("optimization_flags", {}), + }, + "executable": 
kernel_info["executable"], + # Keep backward compatibility fields + "time_ms": result.get("time_ms", 0), + "tflops": result.get("tflops", 0), + "bandwidth_gb_s": result.get("bandwidth_gb_s", 0), + } + + results.append(structured_result) + + if self.verbose: + print( + f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms" + ) + + return results + + def find_best_kernel( + self, results: List[Dict], metric: str = "tflops" + ) -> Optional[Dict]: + """Find the best performing kernel based on metric""" + if not results: + return None + + if metric == "tflops": + return max(results, key=lambda x: x.get("tflops", 0)) + elif metric == "time_ms": + return min(results, key=lambda x: x.get("time_ms", float("inf"))) + elif metric == "bandwidth_gb_s": + return max(results, key=lambda x: x.get("bandwidth_gb_s", 0)) + else: + raise ValueError(f"Unknown metric: {metric}") + + def benchmark_sweep( + self, + problem_sizes: List[Tuple[int, int, int]], + split_k_values: List[int] = [1], + verify: bool = False, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> Dict: + """Run comprehensive benchmark sweep""" + kernels = self.discover_kernels() + if not kernels: + print("No kernels found!") + return {} + + all_results = [] + best_kernels = {} + + for m, n, k in problem_sizes: + for split_k in split_k_values: + results = self.benchmark_problem_size( + kernels, + m, + n, + k, + split_k, + verify=2 if verify else 0, + warmup=warmup, + repeat=repeat, + flush_cache=flush_cache, + rotating_count=rotating_count, + ) + + all_results.extend(results) + + # Find best kernel for this configuration + best = self.find_best_kernel(results) + if best: + key = f"m{m}_n{n}_k{k}_splitk{split_k}" + best_kernels[key] = best + print( + f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, 
{best['time_ms']:.2f}ms)" + ) + + self.results = all_results + return best_kernels + + def export_csv(self, filename: str): + """Export all results to CSV""" + if not self.results: + print("No results to export") + return + + # Get all unique keys from results + all_keys = set() + for result in self.results: + all_keys.update(result.keys()) + + # Sort keys for consistent output + fieldnames = sorted(all_keys) + + with open(filename, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + print(f"Results exported to {filename}") + + def export_best_kernels(self, best_kernels: Dict, filename: str): + """Export best kernel selections to file""" + with open(filename, "w") as f: + f.write("# Best kernel selections\n") + f.write( + "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n" + ) + + for key, kernel in sorted(best_kernels.items()): + f.write( + f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n" + ) + + print(f"Best kernels exported to {filename}") + + def export_json(self, filename: str, best_kernels: Dict = None): + """Export all results and best kernels to JSON with comprehensive metadata""" + from datetime import datetime + + # Calculate comprehensive summary statistics for all metrics + successful_results = [r for r in self.results if r.get("tflops", 0) > 0] + + tflops_values = [r.get("tflops", 0) for r in successful_results] + bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results] + latency_values = [ + r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0 + ] + + # Performance breakdown by kernel type + pipeline_stats = {} + scheduler_stats = {} + data_type_stats = {} + + for result in successful_results: + # Get config info from the new structure + config = result.get("config", {}) + + # Pipeline statistics + pipeline = 
config.get("pipeline", "unknown") + if pipeline not in pipeline_stats: + pipeline_stats[pipeline] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + pipeline_stats[pipeline]["count"] += 1 + pipeline_stats[pipeline]["best_tflops"] = max( + pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0) + ) + + # Scheduler statistics + scheduler = config.get("scheduler", "unknown") + if scheduler not in scheduler_stats: + scheduler_stats[scheduler] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + scheduler_stats[scheduler]["count"] += 1 + scheduler_stats[scheduler]["best_tflops"] = max( + scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0) + ) + + # Data type statistics + data_type = config.get("data_type", "unknown") + if data_type not in data_type_stats: + data_type_stats[data_type] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + data_type_stats[data_type]["count"] += 1 + data_type_stats[data_type]["best_tflops"] = max( + data_type_stats[data_type]["best_tflops"], result.get("tflops", 0) + ) + + # Calculate averages for breakdown stats + for stats_dict, field_name in [ + (pipeline_stats, "pipeline"), + (scheduler_stats, "scheduler"), + (data_type_stats, "data_type"), + ]: + for key in stats_dict: + relevant_results = [ + r + for r in successful_results + if r.get("config", {}).get(field_name, "unknown") == key + ] + if relevant_results: + stats_dict[key]["avg_tflops"] = sum( + r.get("tflops", 0) for r in relevant_results + ) / len(relevant_results) + + output_data = { + "benchmark_metadata": { + "timestamp": datetime.now().isoformat(), + "total_kernels_tested": len(self.results), + "unique_kernels": len( + set(r.get("name", "unknown") for r in self.results) + ), + "successful_runs": len(successful_results), + "failed_runs": len(self.results) - len(successful_results), + }, + "performance_summary": { + "tflops_stats": { + "best": max(tflops_values, default=0), + "average": sum(tflops_values) / 
len(tflops_values) + if tflops_values + else 0, + "min": min(tflops_values, default=0), + "median": sorted(tflops_values)[len(tflops_values) // 2] + if tflops_values + else 0, + }, + "bandwidth_stats": { + "best_gb_s": max(bandwidth_values, default=0), + "average_gb_s": sum(bandwidth_values) / len(bandwidth_values) + if bandwidth_values + else 0, + "min_gb_s": min(bandwidth_values, default=0), + "median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2] + if bandwidth_values + else 0, + }, + "latency_stats": { + "best_ms": min(latency_values, default=0), + "average_ms": sum(latency_values) / len(latency_values) + if latency_values + else 0, + "max_ms": max(latency_values, default=0), + "median_ms": sorted(latency_values)[len(latency_values) // 2] + if latency_values + else 0, + }, + "kernel_type_breakdown": { + "by_pipeline": pipeline_stats, + "by_scheduler": scheduler_stats, + "by_data_type": data_type_stats, + }, + "total_problem_configurations": len(best_kernels) + if best_kernels + else 0, + }, + "kernel_results": self.results, + "best_kernels_by_problem": best_kernels or {}, + } + + with open(filename, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"JSON results exported to {filename}") + print(f"  - Total kernels: {len(self.results)}") + print(f"  - Successful runs: {len(successful_results)}") + print(f"  - Best TFLOPS: {max(tflops_values, default=0):.2f}") + print(f"  - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s") + print(f"  - Best latency: {min(latency_values, default=0):.2f}ms") + + +def main(): + parser = argparse.ArgumentParser(description="GEMM Kernel Benchmarking Tool") + parser.add_argument( + "build_dir", help="Build directory containing kernel executables" + ) + parser.add_argument( + "--problem-sizes", + nargs="+", + default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], + help="Problem sizes as M,N,K tuples", + ) + parser.add_argument( + "--split-k", nargs="+", type=int, default=[1], help="Split-K values 
to test" + ) + parser.add_argument("--verify", action="store_true", help="Enable verification") + parser.add_argument( + "--csv", default="gemm_benchmark_results.csv", help="CSV output filename" + ) + parser.add_argument( + "--best", default="best_kernels.txt", help="Best kernels output filename" + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--warmup", + type=int, + default=50, + help="Number of warmup iterations (default: 50)", + ) + parser.add_argument( + "--repeat", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + default=True, + help="Enable cache flushing (default: True)", + ) + parser.add_argument( + "--rotating-count", + type=int, + default=1000, + help="Number of iterations to rotate cache (default: 1000)", + ) + parser.add_argument("--json", help="JSON output filename (optional)") + + args = parser.parse_args() + + # Parse problem sizes + problem_sizes = [] + for size_str in args.problem_sizes: + try: + m, n, k = map(int, size_str.split(",")) + problem_sizes.append((m, n, k)) + except ValueError: + print(f"Invalid problem size: {size_str}") + return 1 + + # Create benchmark instance + benchmark = GemmBenchmark(args.build_dir, verbose=args.verbose) + + # Run benchmark sweep + print("Starting GEMM kernel benchmark sweep...") + start_time = time.time() + + best_kernels = benchmark.benchmark_sweep( + problem_sizes=problem_sizes, + split_k_values=args.split_k, + verify=args.verify, + warmup=args.warmup, + repeat=args.repeat, + flush_cache=args.flush_cache, + rotating_count=args.rotating_count, + ) + + elapsed_time = time.time() - start_time + print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") + + # Export results + benchmark.export_csv(args.csv) + benchmark.export_best_kernels(best_kernels, args.best) + + # Export JSON if requested + if args.json: + benchmark.export_json(args.json, 
best_kernels) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp new file mode 100644 index 0000000000..5188915f1a --- /dev/null +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Permutation function for pk_int4_t +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = 
tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 6, i) = i4x2; + } + } + } +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("compv3") != std::string::npos) + { + traits.pipeline = "compv3"; + } + else if(kernel_name.find("compv4") != std::string::npos) + { + traits.pipeline = "compv4"; + } + else if(kernel_name.find("mem") != std::string::npos) + { + traits.pipeline = "mem"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else + { + traits.scheduler = "intrawave"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be 
extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp deleted file mode 100644 index f28f5dd29c..0000000000 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck_tile/host.hpp" -#include "gemm_dispatcher.hpp" -#include "gemm_common.hpp" - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -inline auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") - .insert("n", "4096", "The value for n dimension. Default is 4096.") - .insert("k", "2048", "The value for k dimension. Default is 2048.") - .insert("stride_a", "0", "The stride value for tensor A. 
Default is 0.") - .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") - .insert("stride_c", "0", "The stride value for tensor C Default is 0.") - .insert("split_k", "1", "The split value for k dimension. Default is 1.") - .insert("verify", - "2", - "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " - "for validation on GPU. Default is 2, validation on GPU.") - .insert("log", - "false", - "Wether output kernel instance information or not. Possible values are true or " - "false. Default is false") - .insert( - "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") - .insert( - "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") - .insert("timer", - "true", - "Whether if the timer is gpu timer or not. Possible values are false or true. " - "Default is true.") - .insert("init", - "0", - "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " - "for constant(1). Default is 0, random.") - .insert("flush_cache", - "false", - "To flush cache, possible values are true or false. " - "Default is false.") - .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.") - .insert("metric", - "0", - "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " - "tflops, or 2 for bandwidth. Default is 0, latency.") - .insert("csv_filename", - "gemm_kernel", - "The filename of benchmark result. Default is gemm_kernel.") - .insert("structured_sparsity", - "false", - "Whether use sparsity kernel or not. Possible values are true or false. Default is " - "false") - .insert( - "pipeline", - "compv3", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") - .insert("scheduler", - "intrawave", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " - "compv3.") - .insert( - "epilogue", - "cshuffle", - "The type of epilogue. 
Possible values are cshuffle or default. Default is csshuffle.") - .insert("pad_m", - "false", - "Whether pad or not in m direction. Possible values are true or false. Default is " - "false.") - .insert("pad_n", - "false", - "Whether pad or not in n direction. Possible values are true or false. Default is " - "false.") - .insert("pad_k", - "false", - "Whether pad or not in k direction. Possible values are true or false. Default is " - "false.") - .insert("persistent", "false", "Whether to use persistent kernel. Default is false."); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - -auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) -{ - KernelTraits trait; - trait.pipeline = arg_parser.get_str("pipeline"); - trait.scheduler = arg_parser.get_str("scheduler"); - trait.epilogue = arg_parser.get_str("epilogue"); - trait.pad_m = arg_parser.get_bool("pad_m"); - trait.pad_n = arg_parser.get_bool("pad_n"); - trait.pad_k = arg_parser.get_bool("pad_k"); - 
trait.persistent = arg_parser.get_bool("persistent"); - - bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); - - return GemmDispatcher::dispatch(structured_sparsity, trait); -} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py old mode 100755 new mode 100644 index 7def4e2691..d679be7b84 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,361 +1,597 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -generate kernel instances to speed up compilation -""" +#!/usr/bin/env python +import os +import json import argparse import itertools +import multiprocessing +import concurrent.futures from pathlib import Path -from typing import List, Optional -from json_config import GemmConfig, RangeConfigParam -from codegen_utils import ( - DATA_TYPE_MAP, - LAYOUT_MAP, - PIPELINE_MAP, - SCHEDULER_MAP, - EPILOGUE_MAP, - BOOL_MAP, - warp_tile_supported_combinations, - trait_unsupported_combinations, - element_size, - get_gpu_name_by_id, -) import logging +from validation_utils import is_tile_config_valid, is_trait_combination_valid logging.basicConfig(level=logging.INFO) -class GemmCodeGenerator: - """GEMM (General Matrix Multiplication) code generator.""" +class GemmKernelBuilder: + def __init__(self, working_path, datatype, layout, config_json=None): + self.working_path = Path(working_path) + self.datatype = datatype + self.layout = layout + self.config_json = config_json - def __init__( - self, output_dir: str, user_provided_config: Optional[GemmConfig] = None - ): - self.output_dir = Path(output_dir) - self.output_dir.mkdir(parents=True, exist_ok=True) + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) - if user_provided_config is not None: - self.config = user_provided_config + # Load configuration + if 
config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) else: - config_path = ( - Path(__file__).resolve().parent / "configs" / "default_config.json" + self.config = self._get_default_config() + + def _get_default_config(self): + """Return default configuration if no config file is provided""" + # Define base tile configurations that work for all layouts + base_fp16_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + }, + ] + + base_fp8_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 4, + "warp_n": 1, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + }, + ] + + # Create configurations for all supported layouts + all_layouts = ["rcr", "rrr", "ccr", "crr"] + tile_configs = {} + + for datatype, base_configs in [ + ("fp16", base_fp16_configs), + ("fp8", base_fp8_configs), + ]: + tile_configs[datatype] = {} + for layout in all_layouts: + tile_configs[datatype][layout] = base_configs + + return { + "tile_configs": tile_configs, + "traits": { + "pipelines": ["mem", "compv3", "compv4"], + "epilogues": ["default", "cshuffle"], + "schedulers": ["intrawave", "interwave"], + }, + "structured_sparsity": ["false"], + "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]}, + "persistent": ["false"], + } + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + if "tile_configs" in self.config: + # Old format + return ( + 
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) ) - self.config = GemmConfig.from_json(config_path) + elif "tile_config" in self.config: + # New format - generate combinations from individual parameter values + tile_config = self.config["tile_config"] - self.valid_trait_names: List[str] = [] - self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {} + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) + tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) + tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) + warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) + warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) + warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) + warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) + warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) + warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) - def list_all_trait_names(self): - """List all possible kernel trait names into file.""" - w_p = Path(self.output_dir) - file_path = w_p / "gemm_instance_blobs.txt" - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - file_range_map = {} - # Write all file paths to the header file - files_listed = 0 - with file_path.open("w") as f: - # Core files - core_files = [ - "gemm_common.hpp", - "gemm_instances.hpp", - "gemm_dispatcher.hpp", - ] - for core_file in core_files: - f.write(str(w_p / core_file) + "\n") - files_listed += 1 - - # Trait header files - for trait in self.valid_trait_names: - trait_file = f"gemm_{trait}.hpp" - f.write(str(w_p / trait_file) + "\n") - files_listed += 1 - file_name = set() - # Instance source files - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - start_idx = files_listed - for tile in tile_valid_params: - for ( - tile_m, - tile_n, 
- tile_k, - warp_m, - warp_n, - warp_k, - _, - _, - _, - ) in tile: - instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - - if instance_name not in file_name: - file_name.add(instance_name) - f.write(str(w_p / instance_name) + "\n") - files_listed += 1 - - file_range_map[trait] = (start_idx, files_listed) - - file_path = w_p / "gemm_instance_blobs_range.txt" - with file_path.open("w") as f: - for name, ranges in file_range_map.items(): - s, l = ranges - f.write(name + " " + f"{s}" + " " + f"{l}" + "\n") - - def _generate_all_traits(self): - """Generate all possible kernel traits names.""" - params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"] - - # Generate all unique_combinations - _unique = set( - itertools.product( - *[getattr(self.config.trait_config, param).values for param in params] - ) - ) - - for combo in _unique: - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo - current_combination = (pipeline, epilogue, scheduler) - - if current_combination not in trait_unsupported_combinations: - trait_name = ( - f"{pipeline}_{epilogue}_{scheduler}_" - f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_" - f"{BOOL_MAP(persistent)}" - ) - self.valid_trait_names.append(trait_name) - else: - logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}") - - def generate_all_instance_files(self): - """Generate all kernel instances files.""" - self._generate_common_header_file() - self._generate_all_trait_files() - self._generate_dispatcher_file() - - def _generate_common_header_file(self): - """Generate common header file with datatypes and layout.""" - - # Determine appropriate accumulation type based on input types - a_type = self.config.problem.datatype_map["matrix_a"] - b_type = self.config.problem.datatype_map["matrix_b"] - c_type = self.config.problem.datatype_map["matrix_c"] - - if a_type in ["int8", "int4"] and b_type in ["int8", "int4"]: - acc_type = 
"ck_tile::int32_t" + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs else: - acc_type = "float" + # Fallback to default + return [] - content = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" - -// Data types -using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_a"]]}; -using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_b"]]}; -using AccDataType = {acc_type}; -using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_c"]]}; - -// Layout configurations -using ALayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_a"]]}; -using BLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_b"]]}; -using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]}; -""" - - (self.output_dir / "gemm_common.hpp").write_text(content) - - def _generate_all_trait_files(self): - """Generate all kernel traits into files.""" - if not self.valid_trait_names: - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - for trait in self.valid_trait_names: - self._generate_trait_file(trait) - self._generate_instantiation_source_files() - self._generate_common_instance_header_file() - - def _generate_trait_file(self, trait: str): - """Generate a trait with all tile/warp combinations.""" - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_") - filename = f"gemm_{trait}.hpp" - - content = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "gemm_common.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/host.hpp" - -namespace {trait} {{ -""" - # Add template struct with configuration - content += self._generate_kernel_struct( - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent) - - content += f"\n}} // namespace {trait}\n" - (self.output_dir / filename).write_text(content) - - def _generate_kernel_struct( + def _validate_tile_config( self, - pipeline: str, - epilogue: str, - scheduler: str, - pad_m: str, - pad_n: str, - pad_k: str, - persistent: str, - ) -> str: - """Generate the code block of kernel struct""" - return f""" + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="mem", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False -template -struct GemmKernel {{ - static constexpr bool kPadM = {pad_m}; - static constexpr bool kPadN = {pad_n}; - static constexpr bool kPadK = {pad_k}; - static constexpr bool kPersistent = {persistent}; + # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False - static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - static constexpr bool permuteA = false; - static constexpr bool permuteB = false; - static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; - static constexpr bool TransposeC = false; + return True + else: + # Full validation for generation + # 
Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence, - permuteA, - permuteB>; + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + ) - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + if "traits" in self.config: + # Old format + traits = self.config["traits"] + pipelines = traits["pipelines"] + epilogues = traits["epilogues"] + schedulers = traits["schedulers"] - using Traits = - ck_tile::TileGemmTraits; + padding = self.config["padding"] + persistent = self.config["persistent"] - using GemmUniversalTraits = - ck_tile::TileGemmUniversalTraits; + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + padding["pad_m"], + padding["pad_n"], + padding["pad_k"], + persistent, + ) + ) - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) - using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + elif "trait_config" in self.config: + # New format + trait_config = self.config["trait_config"] 
- const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + pipelines = trait_config.get("pipeline", {}).get("values", ["mem"]) + epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) + schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"]) + pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) + pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) + pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) + persistent_values = trait_config.get("persistent", {}).get( + "values", [False] + ) + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) + else: + # Fallback to minimal default + combinations = [("mem", "default", "intrawave", False, False, False, False)] + + return combinations + + def _get_dtype_string(self): + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(self.datatype, "float") + + _LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", + } + + def _get_abc_layouts(self, layout_code: str | None = None): + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 
'rrr'. + If layout_code is None, use self.layout. + """ + if layout_code is None: + # fall back to the instance field + layout_code = getattr(self, "layout", "") + + code = str(layout_code).strip().lower() + + if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code): + raise ValueError( + f"Invalid layout '{layout_code}'. " + "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)." + ) + + a_layout = self._LAYOUT_MAP[code[0]] + b_layout = self._LAYOUT_MAP[code[1]] + c_layout = self._LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout + + def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + + # Map scheduler names to the correct enum values + scheduler_type_map = { + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "default": "ck_tile::GemmPipelineScheduler::Default", + } + + # Determine accumulator type based on datatype + acc_type = "float" + if self.datatype 
in ["int8", "int4"]: + acc_type = "ck_tile::int32_t" + + # Determine output type + c_type = self._get_dtype_string() + if self.datatype in ["fp8", "bf8"]: + c_type = "ck_tile::fp16_t" + + # Determine layouts based on self.layout + a_layout, b_layout, c_layout = self._get_abc_layouts() + + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } + + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} +{pragma_line} +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +using ADataType = {self._get_dtype_string()}; +using BDataType = {self._get_dtype_string()}; +using AccDataType = {acc_type}; +using CDataType = {c_type}; + +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; + +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + 
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + + // Traits + static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; + static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; + static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"}; + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + false, false>; + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Traits + using Traits = ck_tile::TileGemmTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>; + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}; + + static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - + float ave_time{{0}}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, 
const auto memory_operation_) {{ constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")}; + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}; + + // Epilogue +""" - using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; - {EPILOGUE_MAP[epilogue]} - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + # Add epilogue configuration based on type + if epilogue == "cshuffle": + instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; +""" + else: # default epilogue + instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, 
+ ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; +""" - if(!Kernel::IsSupportedArgument(kargs)) - {{ + instance_code += f""" + + // Kernel type + using GemmKernel = ck_tile::GemmKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'}; - - if(stream.log_level_ > 0) - {{ - std::cout << "Launching kernel with args:" - << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }} - - if(stream.flush_cache_) - {{ - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 
2 : 1; - - auto is_row_major = [](auto layout_) {{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{{}}; - }}; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{{}}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{{}}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() {{ - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); - }}; - ave_time = ck_tile::launch_kernel_time_mask( - stream, - run_flush_cache, - ck_tile::make_kernel( - Kernel{{}}, grids, blocks, 0, kargs)); - }} - else{{ - ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( - Kernel{{}}, grids, blocks, 0, kargs)); + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; }} + + // Launch kernel + constexpr int kBlockPerCu = 1; + ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + return ave_time; - }}; const auto RunSplitk = [&](const auto has_hot_loop_, 
const auto tail_number_) {{ @@ -373,484 +609,324 @@ struct GemmKernel {{ }}; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; }} - - static std::string get_name() {{ - return std::string("gemm_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + - "_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" + - std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" + - "{pad_m}" + "_" + - "{pad_n}" + "_" + - "{pad_k}" + "_" + - "{pipeline}" + "_" + - "{epilogue}" + "_" + - "{scheduler}" + "_" + - "{persistent}"; - }} }}; """ - def _generate_common_instance_header_file(self): - """Generate common instance header into file.""" - content = """// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#pragma once -""" - for trait in self.valid_trait_names: - content += f'#include "gemm_{trait}.hpp"\n' - (self.output_dir / "gemm_instances.hpp").write_text(content) + return kernel_name, instance_code - def is_tile_valid(self, tile: tuple, trait: str) -> bool: - """Check if the tile configuration is valid for the given trait.""" - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile - pipeline, *_ = trait.split("_") + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues - # Parameter validity check - invalid_params = [] - if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]: - invalid_params.append( - f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})" - ) - if (warp_m * warp_tile_m) == 0: - invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})") - if (warp_n * warp_tile_n) == 
0: - invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})") - if (warp_k * warp_tile_k) == 0: - invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})") + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() - if invalid_params: - logging.debug( - f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. " - f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), " - f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})" - ) - return False - # Dimension alignment check - alignment_issues = [] - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - if alignment_issues: - logging.debug( - f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # LDS capacity verification - matrix_a_size = (tile_m * tile_k) * element_size( - self.config.problem.datatype_map["matrix_a"] - ) - matrix_b_size = (tile_n * tile_k) * element_size( - self.config.problem.datatype_map["matrix_b"] - ) - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 - - if total_tile_in_lds > max_tile_size: - logging.debug( - f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). 
Breakdown:\n" - f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False - - # Warp combination validation - warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}" - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - gpu_name = get_gpu_name_by_id(0) - - gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) - if not gpu_warp_tile_key: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) - if not allowed_combinations: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - if current_combination not in allowed_combinations: - logging.debug( - f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. 
" - f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}" - ) - return False - - return True - - def _get_valid_trait_tile_combinations(self): - def get_tile_value(tile_param): - return ( - tile_param.generate_candidates() - if isinstance(tile_param, RangeConfigParam) - else tile_param.values - ) - - tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.tile_m), - get_tile_value(self.config.tile_config.tile_n), - get_tile_value(self.config.tile_config.tile_k), - ) - ) - - warp_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_m), - get_tile_value(self.config.tile_config.warp_n), - get_tile_value(self.config.tile_config.warp_k), - ) - ) - - warp_tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_tile_m), - get_tile_value(self.config.tile_config.warp_tile_n), - get_tile_value(self.config.tile_config.warp_tile_k), - ) - ) - - tile_params = { - t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group - } - - for trait in self.valid_trait_names: - tile_valid_params = [ - tile for tile in tile_params if self.is_tile_valid(tile, trait) - ] - - if trait not in self.valid_trait_tile_combinations: - self.valid_trait_tile_combinations[trait] = [] - self.valid_trait_tile_combinations[trait].append(tile_valid_params) - - def _generate_instantiation_source_files(self): - """Generate kernel instance instantiation source files""" - tile_map = {} - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - for tile in tile_valid_params: - for ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) in tile: - key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}" - value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - if key not in tile_map: - tile_map[key] = set() - tile_map[key].add(value) - - files_listed = 0 - for trait, _ in 
self.valid_trait_tile_combinations.items(): - for block_tile, warp_tiles in tile_map.items(): - tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map( - int, block_tile.split("x") + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.working_path, + self.datatype, + self.layout, + ) ) - content = f""" -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - - -#include "gemm_{trait}.hpp" - -""" - for warp_tile in warp_tiles: - warp_tile_m, warp_tile_n, warp_tile_k = map( - int, warp_tile.split("x") - ) - - sparse = ( - self.config.problem.datatype_map["matrix_a"] == "fp16" - and self.config.problem.datatype_map["matrix_b"] == "fp16" - and self.config.problem.datatype_map["matrix_c"] == "fp16" - and ( - ( - warp_tile_m == 32 - and warp_tile_n == 32 - and warp_tile_k == 16 - ) - or ( - warp_tile_m == 16 - and warp_tile_n == 16 - and warp_tile_k == 32 - ) - ) - ) - if sparse: - files_listed = files_listed + 1 - content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;""" - ) - files_listed = files_listed + 1 - content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;""" - ) - content += f""" -""" - ( - self.output_dir - / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - ).write_text(content) - print(f"Generated {files_listed} kernel instances in total.") - - def _generate_dispatcher_file(self): - """Generate the code block of dispatch mechanism.""" - content = """ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include -#include -#include - -#include "gemm_common.hpp" -#include "gemm_instances.hpp" - -/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a -/// specific kernel instance based on the provided settings. -struct KernelTraits -{ - /// @brief The name of the pipeline. - std::string pipeline; - /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). - std::string scheduler; - /// @brief The name of the epilogue (e.g., "cshuffle", "default"). - std::string epilogue; - /// @brief Indicates whether padding is applied to the M dimension. - bool pad_m; - /// @brief Indicates whether padding is applied to the N dimension. - bool pad_n; - /// @brief Indicates whether padding is applied to the K dimension. - bool pad_k; - /// @brief Indicates whether the kernel is persistent. - bool persistent; -}; - -struct GemmDispatcher { - static auto& get_kernel_map() { - // Use a static local variable - static std::unordered_map< - std::string, - std::vector(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>> - kernel_map; - return kernel_map; - } - - static void init([[maybe_unused]]bool structured_sparsity) { - auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; - \n""" - - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - content += f""" kernel_map["{trait}"] = {{""" - for _, tile in enumerate(tile_valid_params): - for j in range(len(tile)): - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile[j] - content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ - content += f""" - if(structured_sparsity){{ // SMFMA""" - sparse = ( - self.config.problem.datatype_map["matrix_a"] == "fp16" - and self.config.problem.datatype_map["matrix_b"] == "fp16" - and self.config.problem.datatype_map["matrix_c"] == "fp16" - and ( - ( - warp_tile_m == 32 - and warp_tile_n 
== 32 - and warp_tile_k == 16 - ) - or ( - warp_tile_m == 16 - and warp_tile_n == 16 - and warp_tile_k == 32 - ) - ) - ) - content += f""" - return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);""" - content += f""" - }} else {{""" - content += f""" - return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(False)}>>(args, stream);""" - content += f""" - }} """ - - if j == len(tile) - 1: - content += f""" - }} """ - else: - content += f""" - }}, """ - content += f""" - }};\n """ - - content += """ } - - template - static std::tuple run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) - { - std::string name = Kernel::get_name(); - float avg_time = Kernel::launch(args, stream); - - return std::make_tuple(name, avg_time); - } - - - static auto dispatch(bool structured_sparsity, const KernelTraits& trait) { - init(structured_sparsity); - const std::string key = assemble_key(trait); - auto& kernel_map = get_kernel_map(); - if(auto it = kernel_map.find(key); it != kernel_map.end()) - { - return it->second; - } - throw std::runtime_error("No suitable kernel found: " + key); - } - -private: - static std::string assemble_key(const KernelTraits &trait) { - return std::string(trait.pipeline) + "_" + - trait.epilogue + "_" + - trait.scheduler + "_" + - (trait.pad_m ? "true" : "false") + "_" + - (trait.pad_n ? "true" : "false") + "_" + - (trait.pad_k ? "true" : "false") + "_" + - (trait.persistent ? 
"true" : "false"); - } -}; - -""" - (self.output_dir / "gemm_dispatcher.hpp").write_text(content) - - -def do_list_blobs( - args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None -): - generator = GemmCodeGenerator(args.working_path, user_provide_config) - generator.list_all_trait_names() - - -def do_gen_blobs( - args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None -): - generator = GemmCodeGenerator(args.working_path, user_provide_config) - generator.generate_all_instance_files() - - -def main(args): - gemm_config = ( - GemmConfig.from_json(args.config_json, args.datatype, args.layout) - if args.config_json is not None - else args.config_json - ) - - if args.list_blobs: - do_list_blobs(args, gemm_config) - elif args.gen_blobs: - do_gen_blobs(args, gemm_config) - else: - logging.warning( - "No mode specified (use --list_blobs or --gen_blobs). Generating by default..." + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." 
) - do_gen_blobs(args, gemm_config) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM targets +# Datatype: {self.datatype}, Layout: {self.layout} + +""" + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = 
trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: + f.write(cmake_code) + + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() + + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } + ) + + # Write kernel count + with open(self.working_path / "gemm_kernel_count.txt", "w") as f: + 
f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") + + def run(self, num_workers=None): + """Run the builder to generate individual kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) -if __name__ == "__main__": +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + tile_config, trait_combo, working_path, datatype, layout = work_item + + # Create a temporary builder instance for this worker + builder = GemmKernelBuilder(working_path, datatype, layout) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_" prefix + # Remove "gemm_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] # Remove "gemm_" prefix + + # Write individual header file + header_file = working_path / f"gemm_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + 
+ +def main(): parser = argparse.ArgumentParser( - prog="generate", - description="gen API for CK gemm kernel", + description="GEMM kernel instance builder with parallel support" ) + parser.add_argument("--working_path", required=True, help="Working directory path") parser.add_argument( - "-w", - "--working_path", - default="./", - required=False, - help="The path where all the blobs are going to be generated", - ) - parser.add_argument( - "-j", - "--config_json", - required=False, - help="Path to the json which contains the configurations that user provide", - ) - parser.add_argument( - "-d", "--datatype", required=True, - help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8", + choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + help="Data type", ) parser.add_argument( - "-ly", "--layout", required=True, - help="Specify what layout to use for the kernel generation, e.g. rcr, rrr", + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - "-l", - "--list_blobs", - action="store_true", - help="List all kernel instances to file", + "--gen_individual", action="store_true", help="Generate individual kernel files" ) parser.add_argument( - "-g", - "--gen_blobs", + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", action="store_true", - help="Generate all kernel instances into different files", + help="List kernel configurations without generating files", ) args = 
parser.parse_args() - main(args) + # Create builder + builder = GemmKernelBuilder( + args.working_path, args.datatype, args.layout, args.config_json + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder.write_kernel_list() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] + + header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + elif args.gen_individual: + # Generate all individual kernel files + builder.run(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + ) + 
+

+if __name__ == "__main__":
+    main()
diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp
index 634e19de6e..bbf0c92e67 100644
--- a/tile_engine/ops/gemm/gemm_profiler.hpp
+++ b/tile_engine/ops/gemm/gemm_profiler.hpp
@@ -20,6 +20,25 @@ class GemmProfiler
         return instance;
     }

+    // Overload for single kernel benchmarking
+    void benchmark(GemmProblem& gemm_problem,
+                   std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
+                       kernel_func)
+    {
+        // Create a vector with a single callable that returns both name and time
+        std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&,
+                                                                 const ck_tile::stream_config&)>>
+            callables;
+
+        callables.push_back(
+            [kernel_func](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {
+                float time = kernel_func(args, stream);
+                return std::make_tuple(std::string(KERNEL_NAME), time);
+            });
+
+        benchmark(gemm_problem, callables);
+    }
+
     void benchmark(GemmProblem& gemm_problem,
                    std::vector<std::function<std::tuple<std::string, float>(
                        ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
@@ -161,7 +180,7 @@ class GemmProfiler
             kernel_instance.perf_result_.tflops_    = static_cast<float>(flop) / 1.E9 / avg_time;
             kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;

-            if(setting_.log_ > 0)
+            if(setting_.log_ > 0 && !setting_.json_output_)
             {
                 std::cout << kernel_instance << std::endl;
             }
@@ -199,10 +218,18 @@ class GemmProfiler
                                              b.perf_result_, a.perf_result_, metric);
                          });

-        std::cout << "**********************************" << std::endl;
-        std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
-                  << "The best kernel instance is: " << kernel_instance << std::endl;
-        std::cout << "**********************************" << std::endl;
+        if(setting_.json_output_)
+        {
+            // Output clean JSON only
+            std::cout << kernel_instance << std::endl;
+        }
+        else
+        {
+            std::cout << "**********************************" << std::endl;
+            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
+                      << "Current kernel performance is: " << kernel_instance <<
std::endl;
+            std::cout << "**********************************" << std::endl;
+        }

         if(!setting_.csv_filename_.empty())
         {
diff --git a/tile_engine/ops/gemm/test_benchmark.sh b/tile_engine/ops/gemm/test_benchmark.sh
new file mode 100755
index 0000000000..1fb7c163af
--- /dev/null
+++ b/tile_engine/ops/gemm/test_benchmark.sh
@@ -0,0 +1,102 @@
+#!/bin/bash
+
+# Test script for tile engine GEMM benchmarks
+# This script demonstrates how to run the new individual benchmark executables
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color
+
+# Find the build directory
+if [ -z "$1" ]; then
+    # Try to find build directory automatically
+    BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1)
+    if [ -z "$BUILD_DIR" ]; then
+        echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}"
+        echo "Usage: $0 <build_dir>"
+        exit 1
+    fi
+else
+    BUILD_DIR="$1"
+fi
+
+echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}"
+
+# Check if bin directory exists
+if [ !
-d "$BUILD_DIR/bin" ]; then
+    echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}"
+    exit 1
+fi
+
+# Find all benchmark executables
+echo -e "${YELLOW}Finding benchmark executables...${NC}"
+BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null)
+
+if [ -z "$BENCHMARKS" ]; then
+    echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}"
+    echo "Please build some benchmarks first with:"
+    echo "  cd $BUILD_DIR"
+    echo "  make benchmark_gemm_<kernel_name>"
+    exit 1
+fi
+
+# Count benchmarks
+NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l)
+echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}"
+
+# Test sizes
+SIZES=(512 1024 2048)
+
+# Results file
+RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv"
+
+echo -e "${YELLOW}Running benchmarks...${NC}"
+echo "Results will be saved to: $RESULTS_FILE"
+
+# Run each benchmark
+COUNTER=0
+for BENCH in $BENCHMARKS; do
+    COUNTER=$((COUNTER + 1))
+    BENCH_NAME=$(basename "$BENCH")
+    echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}"
+
+    for SIZE in "${SIZES[@]}"; do
+        echo -e "  Testing size: ${SIZE}x${SIZE}x${SIZE}"
+
+        # Run with verification
+        "$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \
+            -csv_filename="$RESULTS_FILE" -csv_format=simple \
+            2>&1 | grep -E "(Time:|Performance:|Verification:|Error)"
+
+        if [ ${PIPESTATUS[0]} -ne 0 ]; then
+            echo -e "  ${RED}Benchmark failed!${NC}"
+        fi
+    done
+done
+
+echo -e "\n${GREEN}Benchmark testing complete!${NC}"
+echo "Results saved to: $RESULTS_FILE"
+
+# Show summary if CSV file exists
+if [ -f "$RESULTS_FILE" ]; then
+    echo -e "\n${YELLOW}Summary of results:${NC}"
+    echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)"
+    echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")"
+    echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")"
+fi
+
+# Example of running a specific benchmark with different options
+echo -e "\n${YELLOW}Example commands for manual testing:${NC}"
+echo
"# Basic run:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024" +echo "" +echo "# With CPU verification:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1" +echo "" +echo "# JSON output for parsing:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true" +echo "" +echo "# Performance testing with TFLOPS metric:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1" diff --git a/tile_engine/ops/gemm/test_validation.py b/tile_engine/ops/gemm/test_validation.py new file mode 100644 index 0000000000..1c9a0ff0ca --- /dev/null +++ b/tile_engine/ops/gemm/test_validation.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +""" +Test script to verify that the validation logic is working correctly. 
+"""
+
+from validation_utils import (
+    is_tile_config_valid,
+    is_trait_combination_valid,
+    validate_warp_tile_combination,
+    get_gpu_name_by_id,
+)
+
+
+def test_warp_tile_validation():
+    """Test warp tile combination validation"""
+    print("Testing warp tile combination validation...")
+
+    # Get GPU name
+    gpu_name = get_gpu_name_by_id(0)
+    print(f"Detected GPU: {gpu_name}")
+
+    # Test cases for fp16
+    test_cases = [
+        # ([warp_tile_m, warp_tile_n, warp_tile_k], expected_valid)
+        ([4, 64, 8], False),  # Invalid - not in supported list
+        ([4, 64, 16], True),  # Valid
+        ([32, 32, 8], True),  # Valid
+        ([16, 16, 16], True),  # Valid
+        ([32, 32, 16], True),  # Valid
+        ([16, 16, 32], True),  # Valid
+        ([64, 4, 16], True),  # Valid
+        ([128, 128, 128], False),  # Invalid - too large
+    ]
+
+    print("\nTesting fp16 warp tile combinations:")
+    for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases:
+        valid, msg = validate_warp_tile_combination(
+            warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name
+        )
+        status = "PASS" if valid == expected else "FAIL"
+        print(f"  [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}")
+        if not valid and msg:
+            print(f"    Reason: {msg}")
+
+
+def test_trait_combinations():
+    """Test trait combination validation"""
+    print("\n\nTesting trait combination validation...")
+
+    test_cases = [
+        # (pipeline, epilogue, scheduler, expected_valid)
+        ("mem", "default", "intrawave", True),
+        ("mem", "cshuffle", "intrawave", True),
+        ("compv3", "default", "interwave", False),  # Invalid combination
+        ("compv3", "cshuffle", "interwave", False),  # Invalid combination
+        ("compv4", "default", "interwave", False),  # Invalid combination
+        ("compv4", "cshuffle", "interwave", False),  # Invalid combination
+        ("compv3", "default", "intrawave", True),
+        ("compv4", "cshuffle", "intrawave", True),
+    ]
+
+    print("\nTesting trait combinations:")
+    for pipeline, epilogue, scheduler, expected in test_cases:
+        valid =
is_trait_combination_valid(pipeline, epilogue, scheduler) + status = "PASS" if valid == expected else "FAIL" + print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}") + + +def test_full_tile_config_validation(): + """Test full tile configuration validation""" + print("\n\nTesting full tile configuration validation...") + + # Test case that was failing in the build + tile_m, tile_n, tile_k = 256, 256, 32 + warp_m, warp_n, warp_k = 1, 4, 1 + warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8 + + print("\nTesting problematic configuration:") + print(f" Tile: {tile_m}x{tile_n}x{tile_k}") + print(f" Warp: {warp_m}x{warp_n}x{warp_k}") + print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") + + valid = is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + "fp16", + "fp16", + "fp16", + "mem", + ) + + print(f" Valid: {valid}") + print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)") + + # Test a valid configuration + warp_tile_k = 16 # Change to valid value + print("\nTesting corrected configuration:") + print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") + + valid = is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + "fp16", + "fp16", + "fp16", + "mem", + ) + + print(f" Valid: {valid}") + print(" Expected: True") + + +def main(): + """Run all tests""" + print("=" * 60) + print("GEMM Validation Test Suite") + print("=" * 60) + + test_warp_tile_validation() + test_trait_combinations() + test_full_tile_config_validation() + + print("\n" + "=" * 60) + print("Test suite completed") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py new file mode 100644 index 0000000000..4948fd5744 --- /dev/null +++ b/tile_engine/ops/gemm/validation_utils.py @@ -0,0 +1,342 @@ +#!/usr/bin/env 
python +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +""" +Validation utilities for GEMM kernel generation. +Extracted from tile_engine_develop for consistency. +""" + +import subprocess +import re +from functools import lru_cache +import logging + +# Element size mapping for different data types +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, + "fp32": 4, + "fp64": 8, +} + +# Supported warp tile combinations for different GPU architectures and data types +WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Unsupported trait combinations 
+TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + logging.debug("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + logging.debug("GPU query timeout (5s)") + except Exception as e: + logging.debug(f"GPU detection error: {str(e)}") + + return "" + + +def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is valid.""" + return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS + + +def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: + """Validate warp configuration.""" + return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + +def validate_dimension_alignment( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, +) -> tuple[bool, list[str]]: + """Check if tile dimensions are 
properly aligned with warp dimensions.""" + alignment_issues = [] + + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + return len(alignment_issues) == 0, alignment_issues + + +def validate_lds_capacity( + tile_m: int, + tile_n: int, + tile_k: int, + a_datatype: str, + b_datatype: str, + pipeline: str, +) -> tuple[bool, str]: + """Validate LDS capacity requirements.""" + matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) + matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + + if total_tile_in_lds > max_tile_size: + error_msg = ( + f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). 
Breakdown:\n" + f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False, error_msg + + return True, "" + + +def validate_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str = None, +) -> tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. 
" + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def is_tile_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + trait_name: str = None, +) -> bool: + """ + Comprehensive tile configuration validation. + Returns True if configuration is valid, False otherwise. + """ + # Basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Check that warp tiles fit within block tiles + if warp_m * warp_tile_m > tile_m: + return False + if warp_n * warp_tile_n > tile_n: + return False + if warp_k * warp_tile_k > tile_k: + return False + + # Validate warp configuration + if not validate_warp_configuration(warp_m, warp_n, warp_k): + logging.debug( + f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" + ) + return False + + # Validate dimension alignment + is_aligned, alignment_issues = validate_dimension_alignment( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) + if not is_aligned: + logging.debug( + f"Dimension alignment failed: {', '.join(alignment_issues)}. 
" + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # Validate LDS capacity + lds_valid, lds_error = validate_lds_capacity( + tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline + ) + if not lds_valid: + logging.debug(f"LDS validation failed: {lds_error}") + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + return True