From f622d546f3a4fb937c1013411efaa5b96f6bc672 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 27 Nov 2025 15:49:57 -0700 Subject: [PATCH] Tile engine for streamk (#3157) * [CK TILE STREAMK] Introduce initial support for tile engine in streamk GEMM. - This commit lays the groundwork for integrating the tile engine into streamk GEMM. It focuses on creating benchmark executables for streamk GEMM. - Additional scripts like test_benchmark.sh and gemm_benchmark.py will be added once the streamk implementation reaches stability. * [CK TILE STREAMK] Enable CI to execute tile engine benchmarks for StreamK GEMM * [CK TILE STREAMK] Refactor: Extract common utility functions. * [CK TILE STREAMK] Revise tile engine of streamk to align with the updated implementation * Add pre-commit * [CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK * [CK TILE STREAMK] Fix a bug about value of 'dp_persistent' of CK TILE STREAMK * [CK TILE STREAMK] Update Jenkinsfile * [CK TILE Engine] Update StreamK tile engine help message Remove default value messages as they are automatically printed * [CK TILE Engine] Update StreamK tile engine - Remove namespace reboot * [CK TILE Engine] Update StreamK tile engine - Fix merge error [ROCm/composable_kernel commit: 30727c48fcdf2178f013cbb843db563abd77d09c] --- Jenkinsfile | 8 +- .../40_streamk_gemm/run_gemm_example.inc | 2 +- .../40_streamk_gemm/streamk_gemm_basic.cpp | 4 +- .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 13 +- .../gemm_streamk/test_gemm_streamk_util.hpp | 3 +- tile_engine/include/utility/validation.hpp | 50 + tile_engine/ops/CMakeLists.txt | 3 +- tile_engine/ops/gemm_streamk/CMakeLists.txt | 295 ++++++ .../gemm_streamk/configs/default_config.json | 105 ++ .../gemm_streamk/gemm_streamk_benchmark.hpp | 201 ++++ .../gemm_streamk_benchmark_single.cpp | 169 ++++ .../ops/gemm_streamk/gemm_streamk_common.hpp | 145 +++ .../gemm_streamk_instance_builder.py | 905 
++++++++++++++++++ .../gemm_streamk/gemm_streamk_profiler.hpp | 296 ++++++ .../gemm_streamk_validation_utils.py | 350 +++++++ 15 files changed, 2530 insertions(+), 19 deletions(-) create mode 100644 tile_engine/include/utility/validation.hpp create mode 100644 tile_engine/ops/gemm_streamk/CMakeLists.txt create mode 100644 tile_engine/ops/gemm_streamk/configs/default_config.json create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp create mode 100644 tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py diff --git a/Jenkinsfile b/Jenkinsfile index c79b8f18e1..a2e5b3d20b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1615,11 +1615,13 @@ pipeline { -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ + -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ @@ -1644,11 +1646,13 @@ pipeline { -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ + -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index 206b9c37fc..ebb5140e50 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -86,7 +86,7 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::tuple ave_time_and_batch; - if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { ave_time_and_batch = gemm gemm(const ck_tile::StreamKHostArgs& args, } auto reset_data_buffers = [&]() { - if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { // Clear the output C tensor results after each repetition of the kernel hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); } - else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) + else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index f941ae7597..91f1358321 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -28,8 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_C_, - StreamKReductionStrategy reduction_strategy_) + index_t stride_C_) : UniversalGemmHostArgs<>({a_ptr_}, {b_ptr_}, {/*ds_ptr*/}, @@ -41,12 +40,9 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> {stride_A_}, {stride_B_}, {/*stride_Ds_*/}, - 
stride_C_), - reduction_strategy{reduction_strategy_} + stride_C_) { } - - ck_tile::StreamKReductionStrategy reduction_strategy; }; /** @@ -133,7 +129,6 @@ struct StreamKKernel host_args.stride_Ds, host_args.stride_E, host_args.k_batch}, - reduction_strategy{host_args.reduction_strategy}, // The workspace pointer is set to nullptr because we must first // instantiate the TilePartitioner to get the necessary size workspace_ptr{nullptr}, @@ -141,10 +136,6 @@ struct StreamKKernel { } - /** - * @brief The strategy used by work groups to compute final results in C tensor. - */ - StreamKReductionStrategy reduction_strategy; /** * @brief A pointer to a buffer in device memory for accumulating partial via reduction * strategy. diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 72b4c52831..213702551a 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -250,8 +250,7 @@ class TestCkTileStreamK : public ::testing::Test K, stride_A, stride_B, - stride_C, - reduction_strategy}; + stride_C}; ck_tile::index_t num_accumulations_per_tile = invoke_streamk( diff --git a/tile_engine/include/utility/validation.hpp b/tile_engine/include/utility/validation.hpp new file mode 100644 index 0000000000..dc57e6cc6a --- /dev/null +++ b/tile_engine/include/utility/validation.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations +bool compare(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + + return pass; +} diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt index db100553f3..405600188c 100644 --- a/tile_engine/ops/CMakeLists.txt +++ b/tile_engine/ops/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(gemm) add_subdirectory(gemm_multi_d) -add_subdirectory(gemm_preshuffle) \ No newline at end of file +add_subdirectory(gemm_preshuffle) +add_subdirectory(gemm_streamk) diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..acfd78edc5 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -0,0 +1,295 @@ +set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") +set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") +set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM targets +function(create_individual_gemm_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") + return() + endif() + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts 
${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_streamk_${datatype}_${layout}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + ${GEMM_SOURCE_DIR}/gemm_streamk_benchmark_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + 
-Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_streamk_all ${target_name}) + add_dependencies(benchmark_gemm_streamk_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${layout} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_streamk_${pipeline} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${epilogue} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${scheduler} ${target_name}) +endfunction() + +# Function to build individual GEMM targets +function(build_individual_gemm_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_CONFIG_FILE + # 2. CMake variable GEMM_CONFIG_FILE + # 3. 
Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(STATUS " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") + message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(STATUS " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(STATUS " Working path: ${working_path}") + message(STATUS " Config file: ${json_blob}") + message(STATUS " Python executable: ${Python3_EXECUTABLE}") + message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + # First, just list the kernels (fast operation) + message(STATUS " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u 
${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() + + # Read kernel count + if(EXISTS ${working_path}/gemm_kernel_count.txt) + file(READ ${working_path}/gemm_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_kernel_list.txt) + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") + endforeach() + else() + message(FATAL_ERROR "Kernel list file not found") + endif() +endfunction() + +# Main build logic - Only individual builds supported +message(STATUS "=== Starting Tile Engine StreamK GEMM Configuration ===") +message(STATUS "GEMM_STREAMK_DATATYPE: ${GEMM_STREAMK_DATATYPE}") +message(STATUS "GEMM_STREAMK_LAYOUT: ${GEMM_STREAMK_LAYOUT}") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942 +set(GEMM_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL 
${target}) + message(STATUS " Adding GPU target: ${target}") + endif() +endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_streamk_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + add_custom_target(benchmark_gemm_streamk_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + add_custom_target(benchmark_gemm_streamk_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + add_custom_target(benchmark_gemm_streamk_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM kernels + set(GEMM_PIPELINES "mem;compv3;compv4") + set(GEMM_EPILOGUES 
"default;cshuffle") + set(GEMM_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_PIPELINES) + add_custom_target(benchmark_gemm_streamk_${pipeline}) + endforeach() + + foreach(epilogue IN LISTS GEMM_EPILOGUES) + add_custom_target(benchmark_gemm_streamk_${epilogue}) + endforeach() + + foreach(scheduler IN LISTS GEMM_SCHEDULERS) + add_custom_target(benchmark_gemm_streamk_${scheduler}) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + build_individual_gemm_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm_streamk/configs/default_config.json b/tile_engine/ops/gemm_streamk/configs/default_config.json new file mode 100644 index 0000000000..f6b92feee3 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/configs/default_config.json @@ -0,0 +1,105 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_k": { + "max": 256, + "min": 64, + "step": 64 + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, true + ] + }, + "reduction_strategy": { + "values": [ + "reduction", "atomic" + ] + } + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp 
b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp new file mode 100644 index 0000000000..fa8a019be5 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_streamk_common.hpp" +#include "utility/validation.hpp" + +// Data types and Layouts are defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts + +enum class Metric +{ + LATENCY = 0, + TFLOPS = 1, + BANDWIDTH = 2 +}; + +inline constexpr auto get_metric_name(Metric m) +{ + switch(m) + { + case Metric::LATENCY: return "latency"; + case Metric::TFLOPS: return "tflops"; + case Metric::BANDWIDTH: return "bandwidth"; + default: throw std::invalid_argument("Unsupported metric type"); + } +} + +struct GemmProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_c_; + + std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_; + std::string layout_a_, layout_b_, layout_c_; + + bool structured_sparsity_; + + friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_c\":" << problem.stride_c_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_c\":\"" << 
problem.layout_c_ << "\",\n" + << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") + << "\n" + << "}"; + return os; + } +}; + +struct PerformanceResult +{ + double latency_; + double tflops_; + double bandwidth_; + + static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) + { + switch(m) + { + case Metric::LATENCY: return a.latency_ < b.latency_; + case Metric::TFLOPS: return a.tflops_ > b.tflops_; + case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; + default: throw std::invalid_argument("Unsupported metric type"); + } + } + + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) + { + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" + << "}"; + return os; + } +}; + +struct KernelInstance +{ + std::string name_; + std::string dp_persistent_; + std::string reduction_strategy_; + GemmProblem problem_; + PerformanceResult perf_result_; + + static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) + { + return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); + } + + friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) + { + os << "{\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"dp_persistent\": \"" << obj.dp_persistent_ << "\",\n" + << " \"reduction_strategy\": \"" << obj.reduction_strategy_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" + << " \"perf_result\": " << obj.perf_result_ << "\n" + << "}"; + return os; + } +}; + +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; + bool json_output_; +}; + +inline std::string get_rocm_version() +{ + std::ifstream 
version_file("/opt/rocm/.info/version"); + if(version_file.is_open()) + { + std::string version; + std::getline(version_file, version); + return version; + } + return "Unknown"; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + c_m_n_host_result.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp new file mode 100644 index 0000000000..13cadcd55a --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_streamk_profiler.hpp" +#include "gemm_streamk_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_streamk_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension.") + .insert("n", "4096", "The value for n dimension.") + .insert("k", "2048", "The value for k dimension.") + .insert("stride_a", "0", "The stride value for tensor A.") + .insert("stride_b", "0", "The stride value for tensor B.") + .insert("stride_c", "0", "The stride value for tensor C.") + .insert("split_k", "1", "The split value for k dimension.") + .insert("verify", + "0", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false.") + .insert("warmup", "50", "The number of iterations before benchmark the kernel.") + .insert("repeat", "100", "The number of iterations to benchmark the kernel.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1).") + .insert("flush_cache", "true", "To flush cache, possible values are true or false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. 
Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false.") + .insert( + "json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void repeat_once_if_verify(Setting& setting) +{ + // The output buffer will be reset after each run, which means the gemm result will be + // accumulated in the output buffer. So limit the repeat to 1 if verify is true. + if(setting.verify_) + { + setting.n_repeat_ = 1; + setting.n_warmup_ = 0; + } +} + +void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + 
arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + repeat_once_if_verify(setting); + + // Get the profiler instance + auto& profiler = GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + { + parser.print(); + return EXIT_FAILURE; + } + + benchmark_gemm_single(parser); + return EXIT_SUCCESS; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp new file mode 100644 index 0000000000..179aeb7307 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("compv3") != std::string::npos) + { + traits.pipeline = "compv3"; + } + else 
if(kernel_name.find("compv4") != std::string::npos) + { + traits.pipeline = "compv4"; + } + else if(kernel_name.find("mem") != std::string::npos) + { + traits.pipeline = "mem"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else + { + traits.scheduler = "intrawave"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py new file mode 100644 index 0000000000..6aebc54564 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -0,0 +1,905 @@ +#!/usr/bin/env python + +import os +import json +import argparse +import itertools +import multiprocessing +import concurrent.futures +from pathlib import Path +import logging +from typing import Optional +from gemm_streamk_validation_utils import ( + is_tile_config_valid, + is_trait_combination_valid, +) + +logging.basicConfig(level=logging.INFO) + + +class GemmKernelBuilder: + def __init__(self, working_path, datatype, layout, config_json=None): + self.working_path = Path(working_path) + self.datatype = datatype + self.layout = layout + self.config_json = config_json + + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) + + # Load configuration + if config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) + else: + self.config = self._get_default_config() + + def _get_default_config(self): + """Return default configuration if no config file is provided""" + # Define base tile 
configurations that work for all layouts + base_fp16_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + }, + ] + + base_fp8_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 4, + "warp_n": 1, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + }, + ] + + # Create configurations for all supported layouts + all_layouts = ["rcr", "rrr", "ccr", "crr"] + tile_configs = {} + + for datatype, base_configs in [ + ("fp16", base_fp16_configs), + ("fp8", base_fp8_configs), + ]: + tile_configs[datatype] = {} + for layout in all_layouts: + tile_configs[datatype][layout] = base_configs + + return { + "tile_configs": tile_configs, + "traits": { + "pipelines": ["mem", "compv3", "compv4"], + "epilogues": ["default", "cshuffle"], + "schedulers": ["intrawave", "interwave"], + }, + "structured_sparsity": ["false"], + "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]}, + "persistent": ["false"], + "reduction_strategy": ["reduction"], + } + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + if "tile_configs" in self.config: + # Old format + return ( + self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + ) + elif "tile_config" in self.config: + # New format - generate combinations from individual parameter values + tile_config = self.config["tile_config"] + + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) + 
tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) + tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) + warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) + warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) + warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) + warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) + warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) + warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + else: + # Fallback to default + return [] + + def _validate_tile_config( + self, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="mem", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + 
+ # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + return True + else: + # Full validation for generation + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype + + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" + + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + ) + + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + if "trait_config" in self.config: + # New format + trait_config = self.config["trait_config"] + + pipelines = trait_config.get("pipeline", {}).get("values", ["mem"]) + epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) + schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"]) + pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) + pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) + pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) + persistent_values = trait_config.get("persistent", {}).get( + "values", [False] + ) + reduction_strategy_value = trait_config.get("reduction_strategy", {}).get( + "values", ["reduction"] + ) + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + reduction_strategy_value, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler, reduction_strategy = combo[:4] + if is_trait_combination_valid( + pipeline, epilogue, scheduler, reduction_strategy + ): + 
combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}-{reduction_strategy}" + ) + else: + # Fallback to minimal default + combinations = [ + ( + "compv3", + "cshuffle", + "intrawave", + "reduction_strategy", + False, + False, + False, + False, + ) + ] + + return combinations + + def _get_dtype_string(self): + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(self.datatype, "float") + + _LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", + } + + def _get_abc_layouts(self, layout_code: Optional[str] = None): + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. + If layout_code is None, use self.layout. + """ + if layout_code is None: + # fall back to the instance field + layout_code = getattr(self, "layout", "") + + code = str(layout_code).strip().lower() + + if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code): + raise ValueError( + f"Invalid layout '{layout_code}'. " + "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)." 
+ ) + + a_layout = self._LAYOUT_MAP[code[0]] + b_layout = self._LAYOUT_MAP[code[1]] + c_layout = self._LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout + + def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + reduction_strategy, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{reduction_strategy}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + + reduction_strategy_map = { + "atomic": "ck_tile::StreamKReductionStrategy::Atomic", + "reduction": "ck_tile::StreamKReductionStrategy::Reduction", + } + + # Determine accumulator type based on datatype + acc_type = "float" + if self.datatype in ["int8", "int4"]: + acc_type = "ck_tile::int32_t" + + # Determine output type + c_type = self._get_dtype_string() + if self.datatype in ["fp8", "bf8"]: + c_type = "ck_tile::fp16_t" + + # Determine layouts based on self.layout + a_layout, b_layout, c_layout = self._get_abc_layouts() + + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} 
+{pragma_line} +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +using ADataType = {self._get_dtype_string()}; +using BDataType = {self._get_dtype_string()}; +using AccDataType = {acc_type}; +using CDataType = {c_type}; + +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; + +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + + // Traits + static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; + static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; + static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool Preshuffle = false; + + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr int kBlockPerCu = 1; + static constexpr bool StructuredSparsity = false; + static 
constexpr bool NumWaveGroup = 1; + + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if str(persistent).lower() == "true" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr ck_tile::StreamKReductionStrategy reduction_strategy = {reduction_strategy_map.get(reduction_strategy, "ck_tile::StreamKReductionStrategy::Reduction")}; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + // Tile partitioner + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + // Traits + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + GemmUniversalTraits>; + + static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ + const auto Run = [&](const auto memory_operation_) {{ + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}; + + // Epilogue + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + + // 
Kernel type + using GemmKernel = ck_tile::StreamKKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + const auto workspace_size = GemmKernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = GemmKernel::GridSize(kargs.tile_partitioner); + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << "\\n" + << "shape: " << TileShape::GetName() << "\\n" + << "problem: " << UniversalGemmProblem::GetName() << "\\n" + << "pipeline: " << GemmPipeline::GetName() << "\\n" + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + }} + + auto reset_data_buffers = [&]() {{ + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + {{ + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); + }} + else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + {{ + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); + }} + }}; + + + // Launch kernel + float ave_time = ck_tile::launch_kernel_time_mask( + stream, + reset_data_buffers, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + return ave_time; + + // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); + // return std::make_tuple(ave_time, num_wgs_per_tile); + }}; + + + if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy) + {{ 
+ return Run(ck_tile::integral_constant{{}}); + }} + else // We are using ck_tile::StreamKReductionStrategy::Reduction + {{ + return Run(ck_tile::integral_constant{{}}); + }} + }} +}}; +""" + + return kernel_name, instance_code + + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.working_path, + self.datatype, + self.layout, + ) + ) + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: 
{completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM targets +# Datatype: {self.datatype}, Layout: {self.layout} + +""" + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: + f.write(cmake_code) + + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() + + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + 
def _generate_single_kernel_individual(work_item):
    """Worker entry point: generate one individual kernel header file.

    ``work_item`` is a 5-tuple ``(tile_config, trait_combo, working_path,
    datatype, layout)``.  Returns ``(kernel_name, trait_combo, tile_config)``
    on success, or ``None`` if generation failed.
    """
    cfg_tile, cfg_traits, out_dir, dtype, matrix_layout = work_item

    # Each worker constructs its own short-lived builder instance.
    worker_builder = GemmKernelBuilder(out_dir, dtype, matrix_layout)

    try:
        kernel_name, instance_code = worker_builder._generate_kernel_instance(
            cfg_tile, cfg_traits
        )

        # The on-disk file name drops the leading "gemm_" prefix of the
        # kernel name.
        simplified = (
            kernel_name[5:] if kernel_name.startswith("gemm_") else kernel_name
        )

        header_path = out_dir / f"gemm_streamk_single_{simplified}.hpp"
        with open(header_path, "w") as header_out:
            header_out.write(instance_code)

        return (kernel_name, cfg_traits, cfg_tile)
    except Exception as e:
        # Best-effort: report and let the pool continue with other kernels.
        print(f"Error generating individual kernel: {e}")
        return None
"--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + # Create builder + builder = GemmKernelBuilder( + args.working_path, args.datatype, args.layout, args.config_json + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder.write_kernel_list() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3], # reduction_strategy + trait_parts[4] == "false", # pad_m + trait_parts[5] == "false", # pad_n + trait_parts[6] == "false", # pad_k + trait_parts[7], # persistent + ) + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] + + header_file = ( + builder.working_path / f"gemm_streamk_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + 
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include
#include
#include

#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gemm_streamk_benchmark.hpp"

// Host-side profiler for Stream-K GEMM kernel instances.
//
// Responsibilities visible in this file: resolve default strides, allocate
// and initialize host tensors and device buffers, run each supplied kernel
// callable under a ck_tile::stream_config, optionally verify the device
// result against a host reference, accumulate per-kernel performance
// records, and report / persist (CSV) the best instance.
//
// NOTE(review): several template argument lists (std::function<...>,
// std::vector<...>, ck_tile::HostTensor<...>, FillConstant<...>,
// std::is_same_v<...>) appear elided in this view of the file — confirm
// against the committed header before relying on exact types.
class GemmProfiler
{
    public:
    // Process-wide singleton. The Setting passed on the FIRST call wins;
    // later arguments are ignored (function-local static initialization).
    static GemmProfiler& instance(Setting setting)
    {
        static GemmProfiler instance{setting};
        return instance;
    }

    // Overload for single kernel benchmarking
    // Wraps one kernel functor (which returns an elapsed time) into the
    // callable list consumed by the main benchmark() overload, labelling
    // it with the compile-time KERNEL_NAME.
    void benchmark(GemmProblem& gemm_problem,
                   std::function kernel_func)
    {
        // Create a vector with a single callable that returns both name and time
        std::vector(ck_tile::StreamKHostArgs&,
                                              const ck_tile::stream_config&)>>
            callables;

        callables.push_back(
            [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
                float time = kernel_func(args, stream);
                return std::make_tuple(std::string(KERNEL_NAME), time);
            });

        benchmark(gemm_problem, callables);
    }

    // Sets up the GEMM problem once, then runs every callable against it and
    // feeds each (name, avg_time_ms) result into process_result().
    void benchmark(GemmProblem& gemm_problem,
                   std::vector(
                       ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
    {
        const ALayout layout_a = ALayout{};
        const BLayout layout_b = BLayout{};
        const CLayout layout_c = CLayout{};

        // Resolve (possibly unspecified) strides from extents and layout
        // majorness before building tensor descriptors.
        gemm_problem.stride_a_ = ck_tile::get_default_stride(
            gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
        gemm_problem.stride_b_ = ck_tile::get_default_stride(
            gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
        gemm_problem.stride_c_ = ck_tile::get_default_stride(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));

        ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
        ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor(
            gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
        ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

        // Input initialization selected by Setting::init_method_:
        // 0 = uniform random in [-1, 1], 1 = monotonic sequence,
        // 2 = constant ones, anything else = zeros.
        if(setting_.init_method_ == 0)
        {
            ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k);
            ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n);
        }
        else if(setting_.init_method_ == 1)
        {
            ck_tile::FillMonotonicSeq{}(a_m_k);
            ck_tile::FillMonotonicSeq{}(b_k_n);
        }
        else if(setting_.init_method_ == 2)
        {
            ck_tile::FillConstant{static_cast(1)}(a_m_k);
            ck_tile::FillConstant{static_cast(1)}(b_k_n);
        }
        else
        {
            a_m_k.SetZero();
            b_k_n.SetZero();
        }

        // Zero out elements of A per structured-sparsity pattern when the
        // problem requests it.
        if(gemm_problem.structured_sparsity_)
        {
            ck_tile::AdjustToStructuredSparsity{}(a_m_k);
        }

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        if constexpr(std::is_same_v)
        {
            // Permute vector pk_i4x4 data for device implementation
            ck_tile::HostTensor b_k_n_dev = b_k_n;
            // permute_tensor_b(b_k_n_dev);
            permute_vectors_i4x4_b(b_k_n_dev);
            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
        }
        else
        {
            b_k_n_dev_buf.ToDevice(b_k_n.data());
        }

        a_m_k_dev_buf.ToDevice(a_m_k.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        // Kernel argument pack shared by every callable in this run.
        ck_tile::StreamKHostArgs gemm_args{a_m_k_dev_buf.GetDeviceBuffer(),
                                           b_k_n_dev_buf.GetDeviceBuffer(),
                                           c_m_n_dev_buf.GetDeviceBuffer(),
                                           gemm_problem.m_,
                                           gemm_problem.n_,
                                           gemm_problem.k_,
                                           gemm_problem.stride_a_,
                                           gemm_problem.stride_b_,
                                           gemm_problem.stride_c_};

        ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

        if(setting_.verify_)
        {
            // Build the reference result once; process_result() compares every
            // kernel's output against it.
            gemm_host_reference(setting_.verify_,
                                a_m_k,
                                b_k_n,
                                c_m_n_host_result,
                                a_m_k_dev_buf,
                                b_k_n_dev_buf,
                                gemm_problem.m_,
                                gemm_problem.n_,
                                gemm_problem.k_,
                                gemm_problem.stride_a_,
                                gemm_problem.stride_b_,
                                gemm_problem.stride_c_);
        }

        for(auto& callable : callables)
        {
            auto kernel_run_result = callable(gemm_args,
                                              ck_tile::stream_config{nullptr,
                                                                     true,
                                                                     setting_.log_,
                                                                     setting_.n_warmup_,
                                                                     setting_.n_repeat_,
                                                                     setting_.is_gpu_timer_,
                                                                     setting_.flush_cache_,
                                                                     setting_.rotating_count_});
            process_result(gemm_problem,
                           c_m_n_dev_buf,
                           c_m_n_host_result,
                           c_m_n_dev_result,
                           kernel_run_result);
        }
    }

    // Converts one (name, avg_time_ms) run result into a KernelInstance with
    // derived TFLOPS / bandwidth, optionally verifies against the host
    // reference, and records the instance on success. Clears the device/host
    // result buffers afterwards so the next kernel starts from zero.
    void process_result(const GemmProblem& gemm_problem,
                        ck_tile::DeviceMem& c_m_n_dev_buf,
                        ck_tile::HostTensor& c_m_n_host_result,
                        ck_tile::HostTensor& c_m_n_dev_result,
                        const std::tuple& kernel_run_result)
    {
        auto [name, avg_time] = kernel_run_result;
        // Compile-time traits of the selected kernel, reported for output.
        auto dp_persistent =
            SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
        auto reduction_strategy =
            SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic
                ? "Atomic"
                : "Reduction";

        // Perf fields start at -1 as "not yet measured" sentinels.
        KernelInstance kernel_instance{
            name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}};

        // compute performance metric
        std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
        std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
                               sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
                               sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;

        // update (avg_time is in ms: flop/1e9/ms -> TFLOPS, bytes/1e6/ms -> GB/s)
        kernel_instance.perf_result_.latency_   = avg_time;
        kernel_instance.perf_result_.tflops_    = static_cast(flop) / 1.E9 / avg_time;
        kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;

        if(setting_.log_ > 0 && !setting_.json_output_)
        {
            std::cout << kernel_instance << std::endl;
        }

        // verify result
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
        bool verified_correct =
            !setting_.verify_ ||
            compare(
                name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);

        if(verified_correct)
        {
            kernel_instances_.emplace_back(kernel_instance);
        }
        else
        {
            // Failed kernels are reported but excluded from best-instance
            // selection.
            std::cout << "Verification failed, skip kernel: " << name << std::endl;
        }

        // clear tensor
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();
    }

    // Picks the best recorded instance under `metric`, prints it (JSON or
    // human-readable), appends a row to the configured CSV file (writing the
    // header first when the file is empty), and returns the instance.
    // Throws std::runtime_error when no instance passed verification.
    KernelInstance select_best_instance(Metric metric)
    {
        if(kernel_instances_.empty())
            throw std::runtime_error("Empty instances");

        // max_element with arguments swapped inside the comparator: the
        // element for which no other compares better wins.
        auto kernel_instance = *std::max_element(kernel_instances_.begin(),
                                                 kernel_instances_.end(),
                                                 [metric](const auto& a, const auto& b) {
                                                     return PerformanceResult::compare(
                                                         b.perf_result_, a.perf_result_, metric);
                                                 });

        if(setting_.json_output_)
        {
            // Output clean JSON only
            std::cout << kernel_instance << std::endl;
        }
        else
        {
            std::cout << "**********************************" << std::endl;
            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
                      << "Current kernel performance is: " << kernel_instance << std::endl;
            std::cout << "**********************************" << std::endl;
        }

        if(!setting_.csv_filename_.empty())
        {
            // Append mode so repeated runs accumulate rows in one file.
            std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);

            if(!file.is_open())
            {
                std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
            }
            else
            {
                // tellp() == 0 means a fresh file: emit the CSV header once.
                if(file.tellp() == 0)
                {
                    file << "rocm_version,device_name,"
                         << "split_k,m,n,k,stride_a,stride_b,stride_c,"
                         << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
                         << "structured_sparsity," << "dp_persistent," << "reduction_strategy,"
                         << "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
                }

                const auto& problem            = kernel_instance.problem_;
                const auto& name               = kernel_instance.name_;
                const auto& dp_persistent      = kernel_instance.dp_persistent_;
                const auto& reduction_strategy = kernel_instance.reduction_strategy_;
                const auto& perf               = kernel_instance.perf_result_;

                file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
                     << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
                     << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
                     << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
                     << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
                     << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
                     << "," << problem.structured_sparsity_ << "," << dp_persistent << ","
                     << reduction_strategy << "," << name << "," << std::fixed
                     << std::setprecision(4) << perf.latency_ << "," << std::fixed
                     << std::setprecision(4) << perf.tflops_ << "," << std::fixed
                     << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
                     << "\n";

                if(!file)
                {
                    std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
                }
            }
        }

        return kernel_instance;
    }

    // Non-copyable singleton.
    GemmProfiler(const GemmProfiler&) = delete;
    GemmProfiler& operator=(const GemmProfiler&) = delete;

    private:
    ~GemmProfiler() { kernel_instances_.clear(); }
    GemmProfiler(Setting setting) : setting_(setting) {}

    // Benchmark configuration captured at first instance() call.
    Setting setting_;

    // Instances that passed verification; input to select_best_instance().
    std::vector kernel_instances_;
};
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Unsupported trait combinations +TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave", "reduction"), + ("compv3", "default", "interwave", "reduction"), + ("compv3", "cshuffle", "interwave", "atomic"), + ("compv3", "default", "interwave", "atomic"), +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. 
gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + logging.debug("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + logging.debug("GPU query timeout (5s)") + except Exception as e: + logging.debug(f"GPU detection error: {str(e)}") + + return "" + + +def is_trait_combination_valid( + pipeline: str, epilogue: str, scheduler: str, reduction_strategy: str +) -> bool: + """Check if a trait combination is valid.""" + return ( + pipeline, + epilogue, + scheduler, + reduction_strategy, + ) not in TRAIT_UNSUPPORTED_COMBINATIONS + + +def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: + """Validate warp configuration.""" + return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + +def validate_dimension_alignment( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, +) -> Tuple[bool, List[str]]: + """Check if tile dimensions are properly aligned with warp dimensions.""" + alignment_issues = [] + + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + return len(alignment_issues) == 0, 
def validate_lds_capacity(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    a_datatype: str,
    b_datatype: str,
    pipeline: str,
) -> Tuple[bool, str]:
    """Check that the A and B block tiles together fit into LDS.

    Returns ``(ok, error_message)``; the message is empty on success.
    """
    bytes_a = (tile_m * tile_k) * element_size(a_datatype)
    bytes_b = (tile_n * tile_k) * element_size(b_datatype)
    total_tile_in_lds = bytes_a + bytes_b

    # compv4 gets half the budget of the other pipelines.
    max_tile_size = (1 << 15) if pipeline == "compv4" else (1 << 16)

    if total_tile_in_lds <= max_tile_size:
        return True, ""

    error_msg = (
        f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
        f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
        f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {bytes_a:,}B\n"
        f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {bytes_b:,}B"
    )
    return False, error_msg


def validate_warp_tile_combination(
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    gpu_name: str = None,
) -> Tuple[bool, str]:
    """Check the warp tile against the GPU-specific supported combinations.

    Unknown GPUs and unknown datatype triples are accepted permissively (with
    a log message) so generation is not blocked on incomplete tables.
    Returns ``(ok, error_message)``; the message is empty on success.
    """
    if gpu_name is None:
        gpu_name = get_gpu_name_by_id(0)

    warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
    current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]

    per_gpu_table = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
    if not per_gpu_table:
        # Unrecognized GPU: permissive, but warn.
        logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
        return True, ""

    allowed_combinations = per_gpu_table.get(warp_tile_key, [])
    if not allowed_combinations:
        # Unlisted datatype triple: permissive.
        logging.debug(
            f"No warp tile combinations found for data types: {warp_tile_key}"
        )
        return True, ""

    if current_combination in allowed_combinations:
        return True, ""

    error_msg = (
        f"Invalid warp tile combination: {current_combination} not in allowed list. "
        f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
    )
    return False, error_msg
" + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # Validate LDS capacity + lds_valid, lds_error = validate_lds_capacity( + tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline + ) + if not lds_valid: + logging.debug(f"LDS validation failed: {lds_error}") + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + return True