From facbc883fa162f3986e22e8b979b4bd110bd1685 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 12 Sep 2025 19:11:13 +0000 Subject: [PATCH] Merge commit 'f6ba94fb5c4d55e51ab002cfb7dec408e9a55fd8' into develop --- Jenkinsfile | 10 + .../arch/amd_buffer_addressing_builtins.hpp | 39 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 15 +- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 1 + tile_engine/ops/CMakeLists.txt | 3 +- .../ops/gemm_preshuffle/CMakeLists.txt | 296 +++++++ .../benchmark_gemm_preshuffle.hpp | 225 +++++ .../benchmark_gemm_preshuffle_single.cpp | 164 ++++ .../commons/validation_utils.py | 375 ++++++++ .../configs/default_config.json | 90 ++ .../configs/user_provided_config.json | 86 ++ .../gemm_preshuffle_benchmark.py | 684 ++++++++++++++ .../gemm_preshuffle_common.hpp | 213 +++++ .../gemm_preshuffle_instance_builder.py | 836 ++++++++++++++++++ .../gemm_preshuffle_profiler.hpp | 275 ++++++ 15 files changed, 3291 insertions(+), 21 deletions(-) create mode 100644 tile_engine/ops/gemm_preshuffle/CMakeLists.txt create mode 100644 tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp create mode 100644 tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp create mode 100644 tile_engine/ops/gemm_preshuffle/commons/validation_utils.py create mode 100644 tile_engine/ops/gemm_preshuffle/configs/default_config.json create mode 100644 tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json create mode 100755 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py create mode 100644 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp create mode 100644 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py create mode 100644 tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp diff --git a/Jenkinsfile b/Jenkinsfile index 54a7da0ee3..87654328b6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1353,10 +1353,15 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_preshuffle_all && \ + python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ @@ -1388,10 +1393,15 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_preshuffle_all && \ + python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 4013b51479..5c7ffefc6a 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1335,8 +1335,10 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1449,14 +1451,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } - else if constexpr(N == 8) + else { - // use fp32 load to mimic fp16 load - fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + // N >= 8: build from fp32x4 chunks + thread_buffer tmp; + static_for<0, (N / 8), 1>{}([&](auto i) { + constexpr index_t chunk = i; + tmp.template get_as()(i) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + (chunk * 4) * sizeof(float), + static_cast(coherence)); + }); return bit_cast(tmp); } } @@ -1486,13 +1493,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe src_wave_addr_offset, static_cast(coherence))); } - else if constexpr(N == 8) + else { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + // N >= 8: build from fp32x4 chunks + thread_buffer tmp; + static_for<0, (N / 8), 1>{}([&](auto i) { + constexpr index_t chunk = i; + tmp.template get_as()(i) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + (chunk * 4) * sizeof(float), + static_cast(coherence)); + }); return bit_cast(tmp); } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index a80ed57be5..3164b41cc7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -20,17 +20,18 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - // using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr bool single_load_tr_length = - (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == - (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); constexpr auto wg_attr_num_access = - ((is_a_load_tr || is_b_load_tr) && !single_load_tr_length) - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; using WarpGemm = WarpGemmDispatcher struct WarpGemmDispatcher; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; + //WMMA cases template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_f8; }; template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8; }; diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt index 7d7002af1b..db100553f3 100644 --- a/tile_engine/ops/CMakeLists.txt +++ b/tile_engine/ops/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(gemm) -add_subdirectory(gemm_multi_d) \ No newline at end of file +add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_preshuffle) \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..2b8f5914f5 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt @@ -0,0 +1,296 @@ +set(GEMM_PRESHUFFLE_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") +set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)") +set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_PRESHUFFLE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM Preshuffle targets +function(create_individual_gemm_preshuffle_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM Preshuffle target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") + return() + endif() + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_preshuffle_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + ${GEMM_PRESHUFFLE_SOURCE_DIR}/benchmark_gemm_preshuffle_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_PRESHUFFLE_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_PRESHUFFLE_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_preshuffle_all ${target_name}) + add_dependencies(benchmark_gemm_preshuffle_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_preshuffle_${layout} ${target_name}) + add_dependencies(benchmark_gemm_preshuffle_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_preshuffle_${pipeline}_pipeline ${target_name}) + add_dependencies(benchmark_gemm_preshuffle_${epilogue}_epilogue ${target_name}) + add_dependencies(benchmark_gemm_preshuffle_${scheduler}_scheduler ${target_name}) +endfunction() + +# Function to build individual GEMM Preshuffle targets +function(build_individual_gemm_preshuffle_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_PRESHUFFLE_CONFIG_FILE + # 2. CMake variable GEMM_PRESHUFFLE_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_PRESHUFFLE_CONFIG_FILE} AND NOT "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(STATUS " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_PRESHUFFLE_CONFIG_FILE}") + message(STATUS " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(STATUS " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(STATUS " Working path: ${working_path}") + message(STATUS " Config file: ${json_blob}") + message(STATUS " Python executable: ${Python3_EXECUTABLE}") + message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + # First, just list the kernels (fast operation) + message(STATUS " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() + + # Read kernel count + if(EXISTS ${working_path}/gemm_preshuffle_kernel_count.txt) + file(READ ${working_path}/gemm_preshuffle_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_preshuffle_kernel_list.txt) + file(STRINGS ${working_path}/gemm_preshuffle_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_preshuffle_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") + endforeach() + else() + message(FATAL_ERROR "Kernel list file not found") + endif() +endfunction() + +# Main build logic - Only individual builds supported +message(STATUS "=== Starting Tile Engine GEMM Preshuffle Configuration ===") +message(STATUS "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}") +message(STATUS "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, and gfx950 +set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL ${target}) + message(STATUS " Adding GPU target: ${target}") + endif() +endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(STATUS "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM_PRESHUFFLE) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_preshuffle_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE) + add_custom_target(benchmark_gemm_preshuffle_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT) + add_custom_target(benchmark_gemm_preshuffle_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE) + foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT) + add_custom_target(benchmark_gemm_preshuffle_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM kernels + set(GEMM_PRESHUFFLE_PIPELINES "preshufflev1;preshufflev2") + set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle") + set(GEMM_PRESHUFFLE_SCHEDULERS "intrawave;interwave;default") + + foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES) + add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline) + endforeach() + + foreach(epilogue IN LISTS GEMM_PRESHUFFLE_EPILOGUES) + add_custom_target(benchmark_gemm_preshuffle_${epilogue}_epilogue) + endforeach() + + foreach(scheduler IN LISTS GEMM_PRESHUFFLE_SCHEDULERS) + add_custom_target(benchmark_gemm_preshuffle_${scheduler}_scheduler) + endforeach() + + # Build individual targets for each datatype/layout combination + + foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE) + foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT) + build_individual_gemm_preshuffle_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp b/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp new file mode 100644 index 0000000000..74fccf6bf2 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp @@ -0,0 +1,225 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_preshuffle_common.hpp" + +//[TODO] Move parts of this File to commons +enum class Metric +{ + LATENCY = 0, + TFLOPS = 1, + BANDWIDTH = 2 +}; + +inline constexpr auto get_metric_name(Metric m) +{ + switch(m) + { + case Metric::LATENCY: return "latency"; + case Metric::TFLOPS: return "tflops"; + case Metric::BANDWIDTH: return "bandwidth"; + default: throw std::invalid_argument("Unsupported metric type"); + } +} + +struct GemmProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_c_; + + std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_; + std::string layout_a_, layout_b_, layout_c_; + + bool structured_sparsity_; + + friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_c\":" << problem.stride_c_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\",\n" + << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") + << "\n" + << "}"; + return os; + } +}; + +struct PerformanceResult +{ + double latency_; + double tflops_; + double bandwidth_; + + static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) + { + switch(m) + { + case Metric::LATENCY: return a.latency_ < b.latency_; + case Metric::TFLOPS: return a.tflops_ > b.tflops_; + case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; + default: throw std::invalid_argument("Unsupported metric type"); + } + } + + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) + { + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" + << "}"; + return os; + } +}; + +struct KernelInstance +{ + std::string name_; + GemmProblem problem_; + PerformanceResult perf_result_; + + static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) + { + return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); + } + + friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) + { + os << "{\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" + << " \"perf_result\": " << obj.perf_result_ << "\n" + << "}"; + return os; + } +}; + +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; + bool json_output_; +}; + +inline std::string get_rocm_version() +{ + std::ifstream version_file("/opt/rocm/.info/version"); + if(version_file.is_open()) + { + std::string version; + std::getline(version_file, version); + return version; + } + return "Unknown"; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations +bool compare(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_ref) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_ref, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + } + else if(verify == 2) + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); + } +} diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp b/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp new file mode 100644 index 0000000000..152e27e77e --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_preshuffle_profiler.hpp" +#include "gemm_preshuffle_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_c", "0", "The stride value for tensor C. Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "2", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU. Default is 0, no validation.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert( + "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") + .insert( + "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "true", + "To flush cache, possible values are true or false. " + "Default is false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false. Default is " + "false") + .insert("json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false. " + "Default is " + "false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + // Get the profiler instance + auto& profiler = GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + + std::tuple warp_tile_dims = std::make_tuple( + SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK); + + auto kernel_func = [](const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func, warp_tile_dims); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + + benchmark_gemm_preshuffle_single(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py new file mode 100644 index 0000000000..2bc42f1ce7 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +""" +Validation utilities for GEMM kernel generation. +Extracted from tile_engine_develop for consistency. +""" + +import subprocess +import re +from functools import lru_cache +import logging +from typing import Tuple, List + +# Element size mapping for different data types +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, + "fp32": 4, + "fp64": 8, +} + +# [TODO] Handle this while moving code to commons +# Supported warp tile combinations for different GPU architectures and data types +WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Unsupported trait combinations +TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + logging.debug("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + logging.debug("GPU query timeout (5s)") + except Exception as e: + logging.debug(f"GPU detection error: {str(e)}") + + return "" + + +def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is valid.""" + return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS + + +def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: + """Validate warp configuration.""" + return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + +def validate_dimension_alignment( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, +) -> Tuple[bool, List[str]]: + """Check if tile dimensions are properly aligned with warp dimensions.""" + alignment_issues = [] + + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + return len(alignment_issues) == 0, alignment_issues + + +def validate_lds_capacity( + tile_m: int, + tile_n: int, + tile_k: int, + a_datatype: str, + b_datatype: str, + pipeline: str, +) -> Tuple[bool, str]: + """Validate LDS capacity requirements.""" + matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) + matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + + if total_tile_in_lds > max_tile_size: + error_msg = ( + f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" + f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False, error_msg + + return True, "" + + +def validate_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str = None, +) -> Tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def is_tile_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + trait_name: str = None, +) -> bool: + """ + Comprehensive tile configuration validation. + Returns True if configuration is valid, False otherwise. + """ + # Basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Check that warp tiles fit within block tiles + if warp_m * warp_tile_m > tile_m: + return False + if warp_n * warp_tile_n > tile_n: + return False + if warp_k * warp_tile_k > tile_k: + return False + + # Validate warp configuration + if not validate_warp_configuration(warp_m, warp_n, warp_k): + logging.debug( + f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" + ) + return False + + # Validate dimension alignment + is_aligned, alignment_issues = validate_dimension_alignment( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) + if not is_aligned: + logging.debug( + f"Dimension alignment failed: {', '.join(alignment_issues)}. " + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # Validate LDS capacity + lds_valid, lds_error = validate_lds_capacity( + tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline + ) + if not lds_valid: + logging.debug(f"LDS validation failed: {lds_error}") + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + return True + + +# [TODO] Handle this while moving code to commons Add more datatype to this function if needed +def get_dtype_string(datatype: str) -> str: + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(datatype, "float") + + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + + +def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. + """ + code = str(layout_code).strip().lower() + + a_layout = LAYOUT_MAP[code[0]] + b_layout = LAYOUT_MAP[code[1]] + c_layout = LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout diff --git a/tile_engine/ops/gemm_preshuffle/configs/default_config.json b/tile_engine/ops/gemm_preshuffle/configs/default_config.json new file mode 100644 index 0000000000..d48b7b0ac2 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/configs/default_config.json @@ -0,0 +1,90 @@ +{ + "tile_config": { + "tile_m": { + "values": [ + 128 + ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 128 + ] + }, + "warp_m": { + "values": [ + 1 + ] + }, + "warp_n": { + "values": [ + 4 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16,32 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "preshufflev1", + "preshufflev2" + ] + }, + "scheduler": { + "values": [ + "interwave", + "intrawave" + ] + }, + "epilogue": { + "values": [ + "default", + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + true, + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json new file mode 100644 index 0000000000..a9937698d3 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json @@ -0,0 +1,86 @@ +{ + "tile_config": { + "tile_m": { + "values": [ + 128 + ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 64 + ] + }, + "warp_m": { + "values": [ + 1 + ] + }, + "warp_n": { + "values": [ + 4 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16,32 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "preshufflev2" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py new file mode 100755 index 0000000000..0217a439f2 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py @@ -0,0 +1,684 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import sys +import json +import subprocess +import argparse +import csv +import time +from pathlib import Path +from typing import List, Dict, Tuple, Optional + + +class GemmPreshuffleBenchmark: + def __init__(self, build_dir: str, verbose: bool = False): + self.build_dir = Path(build_dir) + self.verbose = verbose + self.results = [] + + def discover_kernels(self) -> List[Path]: + """Find all benchmark_gemm_preshuffle* executables in the build directory""" + bin_dir = self.build_dir / "bin" + if not bin_dir.exists(): + print(f"Error: Binary directory {bin_dir} does not exist") + return [] + + kernels = list(bin_dir.glob("benchmark_gemm_preshuffle*")) + if self.verbose: + print(f"Found {len(kernels)} kernel executables") + for k in kernels: + print(f" - {k.name}") + return kernels + + def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]: + """Extract comprehensive kernel information from filename""" + name = kernel_path.stem + + # Initialize with basic info + info = { + "executable": str(kernel_path), + "name": name, + "data_type": "unknown", + "layout": "unknown", + "pipeline": "unknown", + "scheduler": "unknown", + "epilogue": "unknown", + } + + # Parse the kernel name pattern: + # benchmark_gemm_preshuffle_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16 + parts = name.split("_") + + if len(parts) >= 4: + # Extract data type (4rd part after benchmark_gemm_preshuffle_) + info["data_type"] = parts[3] if len(parts) > 2 else "unknown" + + # Extract layout (5th part) + info["layout"] = parts[4] if len(parts) > 3 else "unknown" + + # Extract pipeline (6th part) + info["pipeline"] = parts[5] if len(parts) > 4 else "unknown" + + # Extract epilogue (7th part) + info["epilogue"] = parts[6] if len(parts) > 5 else "unknown" + + # Extract scheduler (8th part) + info["scheduler"] = parts[7] if len(parts) > 6 else "unknown" + + # Extract detailed configuration from the end of the name + config_info = self.parse_detailed_config(name) + info.update(config_info) + + # Generate config ID + info["config_id"] = self.generate_config_id(info) + + return info + + def parse_detailed_config(self, kernel_name: str) -> Dict: + """Parse detailed configuration from kernel name""" + config = { + "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0}, + "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0}, + "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0}, + "optimization_flags": { + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + } + + # Split by underscore and look for patterns + parts = kernel_name.split("_") + + # Look for boolean flags (sequence of True/False values) + bool_sequence = [] + for i, part in enumerate(parts): + if part in ["True", "False"]: + bool_sequence.append(part == "True") + # Continue collecting consecutive boolean values + j = i + 1 + while j < len(parts) and parts[j] in ["True", "False"]: + bool_sequence.append(parts[j] == "True") + j += 1 + break + + # Assign boolean flags if we found them + # Order: pad_m, pad_n, pad_k, persistent (4 flags total) + if len(bool_sequence) >= 4: + config["optimization_flags"]["pad_m"] = bool_sequence[0] + config["optimization_flags"]["pad_n"] = bool_sequence[1] + config["optimization_flags"]["pad_k"] = bool_sequence[2] + config["optimization_flags"]["persistent"] = bool_sequence[3] + + # Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16) + # The pattern is: tile_sizes_warp_config_warp_tile + dimension_groups = [] + for part in parts: + if "x" in part and len(part.split("x")) == 3: + try: + dims = [int(x) for x in part.split("x")] + if all(d > 0 for d in dims): + dimension_groups.append(dims) + except ValueError: + continue + + # Assign dimensions based on order and magnitude + if len(dimension_groups) >= 3: + # Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Largest dimensions = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smallest dimensions = warp config + config["warp_config"]["warp_m"] = sorted_groups[2][0] + config["warp_config"]["warp_n"] = sorted_groups[2][1] + config["warp_config"]["warp_k"] = sorted_groups[2][2] + + # Middle dimensions = warp tile + config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0] + config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1] + config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 2: + # If only 2 groups, assign based on magnitude + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Larger = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smaller = warp config + config["warp_config"]["warp_m"] = sorted_groups[1][0] + config["warp_config"]["warp_n"] = sorted_groups[1][1] + config["warp_config"]["warp_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 1: + # Only one group - assume it's tile sizes + config["tile_sizes"]["tile_m"] = dimension_groups[0][0] + config["tile_sizes"]["tile_n"] = dimension_groups[0][1] + config["tile_sizes"]["tile_k"] = dimension_groups[0][2] + + return config + + def generate_config_id(self, info: Dict) -> str: + """Generate a compact config ID from kernel info""" + # Create a compact identifier + parts = [ + info.get("data_type", "unk"), + info.get("layout", "unk"), + info.get("pipeline", "unk"), + info.get("scheduler", "unk"), + ] + + # Add tile configuration if available + tile_sizes = info.get("tile_sizes", {}) + if tile_sizes.get("tile_m", 0) > 0: + tile_str = ( + f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}" + ) + parts.append(tile_str) + + # Add warp config if available + warp_config = info.get("warp_config", {}) + if warp_config.get("warp_m", 0) > 0: + warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}" + parts.append(warp_str) + + # Add warp tile if available + warp_tile = info.get("warp_tile", {}) + if warp_tile.get("warp_tile_m", 0) > 0: + warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}" + parts.append(warp_tile_str) + + return "_".join(parts) + + def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]: + """Run a single kernel with given parameters and save output to individual JSON file""" + # Create results directory + results_dir = self.build_dir / "results" + results_dir.mkdir(exist_ok=True) + + # Generate unique JSON filename for this kernel + json_file = results_dir / f"{kernel_path.stem}.json" + + cmd = [str(kernel_path)] + + # Add parameters + for key, value in params.items(): + cmd.append(f"-{key}={value}") + + # Add JSON output flag for clean JSON output + cmd.append("-json_output=true") + + if self.verbose: + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if result.returncode != 0: + print(f"Error running {kernel_path.name}: {result.stderr}") + return None + + # Save raw output to individual JSON file + output = result.stdout.strip() + + if output: + with open(json_file, "w") as f: + f.write(output) + + # Parse the JSON file + return self.parse_json_file(json_file) + else: + print(f"No output from {kernel_path.name}") + return None + + except subprocess.TimeoutExpired: + print(f"Timeout running {kernel_path.name}") + return None + except Exception as e: + print(f"Error running {kernel_path.name}: {e}") + return None + + def parse_json_file(self, json_file: Path) -> Optional[Dict]: + """Parse JSON data from individual kernel output file""" + try: + with open(json_file, "r") as f: + content = f.read().strip() + + # Parse the JSON directly since executables produce clean JSON + data = json.loads(content) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON from {json_file}: {e}") + return None + except Exception as e: + if self.verbose: + print(f"Error reading JSON file {json_file}: {e}") + return None + + def benchmark_problem_size( + self, + kernels: List[Path], + m: int, + n: int, + k: int, + split_k: int = 1, + verify: int = 0, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> List[Dict]: + """Benchmark all kernels for a specific problem size""" + results = [] + + params = { + "m": m, + "n": n, + "k": k, + "split_k": split_k, + "verify": verify, + "warmup": warmup, + "repeat": repeat, + "flush_cache": str(flush_cache).lower(), + "rotating_count": rotating_count, + } + + print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}") + + for kernel_path in kernels: + kernel_info = self.extract_kernel_info(kernel_path) + result = self.run_kernel(kernel_path, params) + + if result: + # Create new structured result format + structured_result = { + "name": kernel_info["name"], # Add name field for compatibility + "config_id": kernel_info["config_id"], + "problem": result.get("problem", {}), + "perf_result": result.get("perf_result", {}), + "config": { + "data_type": kernel_info["data_type"], + "layout": kernel_info["layout"], + "pipeline": kernel_info["pipeline"], + "scheduler": kernel_info["scheduler"], + "epilogue": kernel_info["epilogue"], + "tile_sizes": kernel_info.get("tile_sizes", {}), + "warp_config": kernel_info.get("warp_config", {}), + "warp_tile": kernel_info.get("warp_tile", {}), + "optimization_flags": kernel_info.get("optimization_flags", {}), + }, + "executable": kernel_info["executable"], + # Keep backward compatibility fields + "time_ms": result.get("time_ms", 0), + "tflops": result.get("tflops", 0), + "bandwidth_gb_s": result.get("bandwidth_gb_s", 0), + } + + results.append(structured_result) + + if self.verbose: + print( + f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms" + ) + + return results + + def find_best_kernel( + self, results: List[Dict], metric: str = "tflops" + ) -> Optional[Dict]: + """Find the best performing kernel based on metric""" + if not results: + return None + + if metric == "tflops": + return max(results, key=lambda x: x.get("tflops", 0)) + elif metric == "time_ms": + return min(results, key=lambda x: x.get("time_ms", float("inf"))) + elif metric == "bandwidth_gb_s": + return max(results, key=lambda x: x.get("bandwidth_gb_s", 0)) + else: + raise ValueError(f"Unknown metric: {metric}") + + def benchmark_sweep( + self, + problem_sizes: List[Tuple[int, int, int]], + split_k_values: List[int] = [1], + verify: bool = False, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> Dict: + """Run comprehensive benchmark sweep""" + kernels = self.discover_kernels() + if not kernels: + print("No kernels found!") + return {} + + all_results = [] + best_kernels = {} + + for m, n, k in problem_sizes: + for split_k in split_k_values: + results = self.benchmark_problem_size( + kernels, + m, + n, + k, + split_k, + verify=2 if verify else 0, + warmup=warmup, + repeat=repeat, + flush_cache=flush_cache, + rotating_count=rotating_count, + ) + + all_results.extend(results) + + # Find best kernel for this configuration + best = self.find_best_kernel(results) + if best: + key = f"m{m}_n{n}_k{k}_splitk{split_k}" + best_kernels[key] = best + print( + f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)" + ) + + self.results = all_results + return best_kernels + + def export_csv(self, filename: str): + """Export all results to CSV""" + if not self.results: + print("No results to export") + return + + # Get all unique keys from results + all_keys = set() + for result in self.results: + all_keys.update(result.keys()) + + # Sort keys for consistent output + fieldnames = sorted(all_keys) + + with open(filename, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + print(f"Results exported to {filename}") + + def export_best_kernels(self, best_kernels: Dict, filename: str): + """Export best kernel selections to file""" + with open(filename, "w") as f: + f.write("# Best kernel selections\n") + f.write( + "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n" + ) + + for key, kernel in sorted(best_kernels.items()): + f.write( + f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n" + ) + + print(f"Best kernels exported to {filename}") + + def export_json(self, filename: str, best_kernels: Dict = None): + """Export all results and best kernels to JSON with comprehensive metadata""" + from datetime import datetime + + # Calculate comprehensive summary statistics for all metrics + successful_results = [r for r in self.results if r.get("tflops", 0) > 0] + + tflops_values = [r.get("tflops", 0) for r in successful_results] + bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results] + latency_values = [ + r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0 + ] + + # Performance breakdown by kernel type + pipeline_stats = {} + scheduler_stats = {} + data_type_stats = {} + + for result in successful_results: + # Get config info from the new structure + config = result.get("config", {}) + + # Pipeline statistics + pipeline = config.get("pipeline", "unknown") + if pipeline not in pipeline_stats: + pipeline_stats[pipeline] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + pipeline_stats[pipeline]["count"] += 1 + pipeline_stats[pipeline]["best_tflops"] = max( + pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0) + ) + + # Scheduler statistics + scheduler = config.get("scheduler", "unknown") + if scheduler not in scheduler_stats: + scheduler_stats[scheduler] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + scheduler_stats[scheduler]["count"] += 1 + scheduler_stats[scheduler]["best_tflops"] = max( + scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0) + ) + + # Data type statistics + data_type = config.get("data_type", "unknown") + if data_type not in data_type_stats: + data_type_stats[data_type] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + data_type_stats[data_type]["count"] += 1 + data_type_stats[data_type]["best_tflops"] = max( + data_type_stats[data_type]["best_tflops"], result.get("tflops", 0) + ) + + # Calculate averages for breakdown stats + for stats_dict, field_name in [ + (pipeline_stats, "pipeline"), + (scheduler_stats, "scheduler"), + (data_type_stats, "data_type"), + ]: + for key in stats_dict: + relevant_results = [ + r + for r in successful_results + if r.get("config", {}).get(field_name, "unknown") == key + ] + if relevant_results: + stats_dict[key]["avg_tflops"] = sum( + r.get("tflops", 0) for r in relevant_results + ) / len(relevant_results) + + output_data = { + "benchmark_metadata": { + "timestamp": datetime.now().isoformat(), + "total_kernels_tested": len(self.results), + "unique_kernels": len( + set(r.get("name", "unknown") for r in self.results) + ), + "successful_runs": len(successful_results), + "failed_runs": len(self.results) - len(successful_results), + }, + "performance_summary": { + "tflops_stats": { + "best": max(tflops_values, default=0), + "average": sum(tflops_values) / len(tflops_values) + if tflops_values + else 0, + "min": min(tflops_values, default=0), + "median": sorted(tflops_values)[len(tflops_values) // 2] + if tflops_values + else 0, + }, + "bandwidth_stats": { + "best_gb_s": max(bandwidth_values, default=0), + "average_gb_s": sum(bandwidth_values) / len(bandwidth_values) + if bandwidth_values + else 0, + "min_gb_s": min(bandwidth_values, default=0), + "median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2] + if bandwidth_values + else 0, + }, + "latency_stats": { + "best_ms": min(latency_values, default=0), + "average_ms": sum(latency_values) / len(latency_values) + if latency_values + else 0, + "max_ms": max(latency_values, default=0), + "median_ms": sorted(latency_values)[len(latency_values) // 2] + if latency_values + else 0, + }, + "kernel_type_breakdown": { + "by_pipeline": pipeline_stats, + "by_scheduler": scheduler_stats, + "by_data_type": data_type_stats, + }, + "total_problem_configurations": len(best_kernels) + if best_kernels + else 0, + }, + "kernel_results": self.results, + "best_kernels_by_problem": best_kernels or {}, + } + + with open(filename, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"JSON results exported to {filename}") + print(f" - Total kernels: {len(self.results)}") + print(f" - Successful runs: {len(successful_results)}") + print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}") + print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s") + print(f" - Best latency: {min(latency_values, default=0):.2f}ms") + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Preshuffle Kernel Benchmarking Tool" + ) + parser.add_argument( + "build_dir", help="Build directory containing kernel executables" + ) + parser.add_argument( + "--problem-sizes", + nargs="+", + default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], + help="Problem sizes as M,N,K tuples", + ) + parser.add_argument( + "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test" + ) + parser.add_argument("--verify", action="store_true", help="Enable verification") + parser.add_argument( + "--csv", + default="gemm_preshuffle_benchmark_results.csv", + help="CSV output filename", + ) + parser.add_argument( + "--best", default="best_kernels.txt", help="Best kernels output filename" + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--warmup", + type=int, + default=50, + help="Number of warmup iterations (default: 50)", + ) + parser.add_argument( + "--repeat", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + default=True, + help="Enable cache flushing (default: True)", + ) + parser.add_argument( + "--rotating-count", + type=int, + default=1000, + help="Number of iterations to rotate cache (default: 1000)", + ) + parser.add_argument("--json", help="JSON output filename (optional)") + + args = parser.parse_args() + + # Parse problem sizes + problem_sizes = [] + for size_str in args.problem_sizes: + try: + m, n, k = map(int, size_str.split(",")) + problem_sizes.append((m, n, k)) + except ValueError: + print(f"Invalid problem size: {size_str}") + return 1 + + # Create benchmark instance + benchmark = GemmPreshuffleBenchmark(args.build_dir, verbose=args.verbose) + + # Run benchmark sweep + print("Starting GEMM Preshuffle kernel benchmark sweep...") + start_time = time.time() + + best_kernels = benchmark.benchmark_sweep( + problem_sizes=problem_sizes, + split_k_values=args.split_k, + verify=args.verify, + warmup=args.warmup, + repeat=args.repeat, + flush_cache=args.flush_cache, + rotating_count=args.rotating_count, + ) + + elapsed_time = time.time() - start_time + print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") + + # Export results + benchmark.export_csv(args.csv) + benchmark.export_best_kernels(best_kernels, args.best) + + # Export JSON if requested + if args.json: + benchmark.export_json(args.json, best_kernels) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp new file mode 100644 index 0000000000..4fb98dc3c2 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +//[TODO] This can be moved to commons +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// // Permutation function for pk_int4_t +// template +// void permute_vectors_i4x4_b(Tensor& tensor) +// { +// const ck_tile::index_t K = tensor.get_length(0); +// const ck_tile::index_t N = tensor.get_length(1); +// // vector pk_i4x4 permute +// for(int i = 0; i < N; i++) +// { +// for(int j = 0; j < K; j += 8) +// { +// int8_t input[8]; + +// for(int k = 0; k < 4; k++) +// { +// int8_t i4x2 = tensor(j + k * 2, i).data; +// input[k * 2 + 0] = (i4x2 >> 4) & 0xf; +// input[k * 2 + 1] = (i4x2 >> 0) & 0xf; +// } + +// // permute 01234567->20643175 +// { +// int8_t hi = input[2]; +// int8_t lo = input[0]; +// int8_t i4x2 = (hi << 4) | lo; +// tensor(j + 0, i) = i4x2; +// } + +// { +// int8_t hi = input[6]; +// int8_t lo = input[4]; +// int8_t i4x2 = (hi << 4) | lo; +// tensor(j + 2, i) = i4x2; +// } + +// { +// int8_t hi = input[3]; +// int8_t lo = input[1]; +// int8_t i4x2 = (hi << 4) | lo; +// tensor(j + 4, i) = i4x2; +// } + +// { +// int8_t hi = input[7]; +// int8_t lo = input[5]; +// int8_t i4x2 = (hi << 4) | lo; +// tensor(j + 6, i) = i4x2; +// } +// } +// } +// } + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // preshufflev1, preshufflev2 + std::string scheduler; // intrawave, interwave, default + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("preshufflev2"), + scheduler("default"), + epilogue("default"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("preshufflev1") != std::string::npos) + { + traits.pipeline = "preshufflev1"; + } + else if(kernel_name.find("preshufflev2") != std::string::npos) + { + traits.pipeline = "preshufflev2"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else if(kernel_name.find("intrawave") != std::string::npos) + { + traits.scheduler = "intrawave"; + } + else + { + traits.scheduler = "default"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} + +template +auto shuffle_b(const ck_tile::HostTensor& t, + ck_tile::index_t N_Warp_Tile, + ck_tile::index_t K_Warp_Tile) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int divisor = N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py new file mode 100644 index 0000000000..87b826467b --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -0,0 +1,836 @@ +import argparse +import os +import json +import itertools +import logging +import multiprocessing +import concurrent.futures + +from pathlib import Path + +from commons.validation_utils import ( + is_tile_config_valid, + is_trait_combination_valid, + get_dtype_string, + get_abc_layouts, +) + + +class GemmPreshuffleKernelBuilder: + def __init__(self, working_path, datatype, layout, config_json=None): + self.working_path = Path(working_path) + self.datatype = datatype + self.layout = layout + self.config_json = config_json + + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) + + # Load configuration + if config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) + + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() + + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } + ) + + # Write kernel count + with open(self.working_path / "gemm_preshuffle_kernel_count.txt", "w") as f: + f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_preshuffle_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + if "tile_configs" in self.config: + # Old format + return ( + self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + ) + elif "tile_config" in self.config: + # New format - generate combinations from individual parameter values + tile_config = self.config["tile_config"] + + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) + tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) + tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) + warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) + warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) + warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) + warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) + warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) + warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + else: + # Fallback to default + return [] + + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + if "traits" in self.config: + # Old format + traits = self.config["traits"] + pipelines = traits["pipelines"] + epilogues = traits["epilogues"] + schedulers = traits["schedulers"] + + padding = self.config["padding"] + persistent = self.config["persistent"] + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + padding["pad_m"], + padding["pad_n"], + padding["pad_k"], + persistent, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) + + elif "trait_config" in self.config: + # New format + trait_config = self.config["trait_config"] + + pipelines = trait_config.get("pipeline", {}).get("values", ["preshufflev2"]) + epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) + schedulers = trait_config.get("scheduler", {}).get("values", ["default"]) + pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) + pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) + pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) + persistent_values = trait_config.get("persistent", {}).get( + "values", [False] + ) + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) + else: + # Fallback to minimal default + combinations = [ + ("preshufflev2", "default", "default", False, False, False, False) + ] + + return combinations + + def _validate_tile_config( + self, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="preshufflev2", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + return True + else: + # Full validation for generation + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype + + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" + + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + ) + + def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "preshufflev1": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1", + "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "preshufflev1": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1", + "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + # Map scheduler names to the correct enum values + scheduler_type_map = { + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "default": "ck_tile::GemmPipelineScheduler::Default", + } + + # Determine accumulator type based on datatype + acc_type = "float" + + # Determine output type + c_type = get_dtype_string(self.datatype) + if self.datatype in ["fp8", "bf8"]: + c_type = "ck_tile::fp16_t" + + # Determine layouts based on self.layout + a_layout, b_layout, c_layout = get_abc_layouts(self.layout) + + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} +{pragma_line} +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +using ADataType = {get_dtype_string(self.datatype)}; +using BDataType = {get_dtype_string(self.datatype)}; +using AccDataType = {acc_type}; +using CDataType = {c_type}; + +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; + +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + + // Traits + static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; + static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; + static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"}; + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "preshufflev2" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = true; + static constexpr ck_tile::index_t NumWaveGroups = 1; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + false, false>; + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Traits + using Traits = ck_tile::TileGemmTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>; + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}; + + static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")}; + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}; + + // Epilogue +""" + + # Add epilogue configuration based on type + if epilogue == "cshuffle": + instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; +""" + else: # default epilogue + instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; +""" + + instance_code += f""" + + // Kernel type + using GemmKernel = ck_tile::GemmKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + // Launch kernel + constexpr int kBlockPerCu = 1; + ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ + if(args.k_batch == 1) {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; + }} +}}; +""" + + return kernel_name, instance_code + + def run(self, num_workers=None): + """Run the builder to generate individual kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) + + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.working_path, + self.datatype, + self.layout, + ) + ) + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM Preshuffle targets +# Datatype: {self.datatype}, Layout: {self.layout} + +""" + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_preshuffle_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open( + self.working_path / "gemm_preshuffle_individual_targets.cmake", "w" + ) as f: + f.write(cmake_code) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + tile_config, trait_combo, working_path, datatype, layout = work_item + + # Create a temporary builder instance for this worker + builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_" prefix + # Remove "gemm_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] # Remove "gemm_" prefix + + # Write individual header file + header_file = working_path / f"gemm_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument( + "--datatype", + required=True, + choices=["fp16", "fp8", "bf16", "bf8"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", required=True, help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_individual", action="store_true", help="Generate individual kernel files" + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] == "r" and layout_parts[1] == "c", ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + + # Create builder + builder = GemmPreshuffleKernelBuilder( + args.working_path, args.datatype, args.layout, args.config_json + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder.write_kernel_list() + pass + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_preshuffle_"): + simplified_name = simplified_name[16:] + + header_file = ( + builder.working_path / f"gemm_preshuffle_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + elif args.gen_individual: + # Generate all individual kernel files + builder.run(args.num_workers) + pass + else: + parser.error( + "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp new file mode 100644 index 0000000000..4f2a929ba0 --- /dev/null +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -0,0 +1,275 @@ +#pragma once + +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "benchmark_gemm_preshuffle.hpp" + +class GemmProfiler +{ + public: + static GemmProfiler& instance(Setting setting) + { + static GemmProfiler instance{setting}; + return instance; + } + + // Overload for single kernel benchmarking + void benchmark(GemmProblem& gemm_problem, + std::function + kernel_func, + const std::tuple& warp_tile_dims) + { + // Create a vector with a single callable that returns both name and time + std::vector(ck_tile::GemmHostArgs&, + const ck_tile::stream_config&)>> + callables; + + callables.push_back( + [kernel_func](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { + float time = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time); + }); + + benchmark(gemm_problem, callables, warp_tile_dims); + } + + void benchmark(GemmProblem& gemm_problem, + std::vector( + ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables, + const std::tuple& warp_tile_dims) + { + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + gemm_problem.stride_a_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); + gemm_problem.stride_b_ = ck_tile::get_default_stride( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); + gemm_problem.stride_c_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); + + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.init_method_ == 0) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); + } + else if(setting_.init_method_ == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(setting_.init_method_ == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + // Reference Verification + ck_tile::HostTensor c_m_n_ref(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + c_m_n_ref.SetZero(); + + if(setting_.verify_) + { + gemm_host_reference(setting_.verify_, + a_m_k, + b_k_n, + c_m_n_ref, + a_m_k_dev_buf, + b_k_n_dev_buf, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_); + } + + // Kerenl Execution + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + for(const auto& callable : callables) + { + ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims); + ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims); + + ck_tile::HostTensor b_shuffle_host = + shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + + ck_tile::GemmHostArgs gemm_args = { + a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + gemm_problem.split_k_, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_, + }; + + auto kernel_run_result = callable(gemm_args, + ck_tile::stream_config{nullptr, + true, + setting_.log_, + setting_.n_warmup_, + setting_.n_repeat_, + setting_.is_gpu_timer_, + setting_.flush_cache_, + setting_.rotating_count_}); + + process_result( + gemm_problem, c_m_n_dev_buf, c_m_n_ref, c_m_n_dev_result, kernel_run_result); + } + } + + void process_result(const GemmProblem& gemm_problem, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_ref, + ck_tile::HostTensor& c_m_n_dev_result, + const std::tuple& kernel_run_result) + { + auto [name, avg_time] = kernel_run_result; + + KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}}; + + // compute performance metric + std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_; + std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ + + sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ + + sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_; + + // update + kernel_instance.perf_result_.latency_ = avg_time; + kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; + kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; + + if(setting_.log_ > 0 && !setting_.json_output_) + { + std::cout << kernel_instance << std::endl; + } + + // verify result + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + bool verified_correct = + !setting_.verify_ || + compare(name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_ref); + + if(verified_correct) + { + kernel_instances_.emplace_back(kernel_instance); + } + else + { + std::cout << "Verification failed, skip kernel: " << name << std::endl; + } + + // clear tensor + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + KernelInstance select_best_instance(Metric metric) + { + if(kernel_instances_.empty()) + throw std::runtime_error("Empty instances"); + + auto kernel_instance = *std::max_element(kernel_instances_.begin(), + kernel_instances_.end(), + [metric](const auto& a, const auto& b) { + return PerformanceResult::compare( + b.perf_result_, a.perf_result_, metric); + }); + + if(setting_.json_output_) + { + // Output clean JSON only + std::cout << kernel_instance << std::endl; + } + else + { + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "Current kernel performance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + } + + if(!setting_.csv_filename_.empty()) + { + std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app); + + if(!file.is_open()) + { + std::cerr << "Warning: Failed to open CSV file for writing." << std::endl; + } + else + { + if(file.tellp() == 0) + { + file << "rocm_version,device_name," + << "split_k,m,n,k,stride_a,stride_b,stride_c," + << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c," + << "structured_sparsity," << "name," + << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; + } + + const auto& problem = kernel_instance.problem_; + const auto& name = kernel_instance.name_; + const auto& perf = kernel_instance.perf_result_; + + file << get_rocm_version() << "," << ck_tile::get_device_name() << "," + << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," + << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," + << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ + << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << "," + << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_ + << "," << problem.structured_sparsity_ << "," << name << "," << std::fixed + << std::setprecision(4) << perf.latency_ << "," << std::fixed + << std::setprecision(4) << perf.tflops_ << "," << std::fixed + << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) + << "\n"; + + if(!file) + { + std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl; + } + } + } + + return kernel_instance; + } + + GemmProfiler(const GemmProfiler&) = delete; + GemmProfiler& operator=(const GemmProfiler&) = delete; + + private: + ~GemmProfiler() { kernel_instances_.clear(); } + GemmProfiler(Setting setting) : setting_(setting) {} + + Setting setting_; + + std::vector kernel_instances_; +};