diff --git a/Jenkinsfile b/Jenkinsfile index ce88294567..aa4045186e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1642,14 +1642,9 @@ pipeline { ninja -j64 benchmark_gemm_preshuffle_all && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ - ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ - ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ - ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ - ./bin/benchmark_gemm_multi_d_fp16_crrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ - ./bin/benchmark_gemm_multi_d_fp16_rcrr """ + ninja -j64 benchmark_gemm_multi_d_all && \ + python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1682,14 +1677,9 @@ pipeline { ninja -j64 benchmark_gemm_preshuffle_all && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ - ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ - ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ - ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ - ./bin/benchmark_gemm_multi_d_fp16_crrr && \ - ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ - ./bin/benchmark_gemm_multi_d_fp16_rcrr """ + ninja -j64 benchmark_gemm_multi_d_all && \ + python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/tile_engine/ops/gemm/commons/test_benchmark.sh b/tile_engine/ops/commons/test_benchmark.sh similarity index 100% rename from tile_engine/ops/gemm/commons/test_benchmark.sh rename to tile_engine/ops/commons/test_benchmark.sh diff --git a/tile_engine/ops/gemm/commons/test_validation.py b/tile_engine/ops/commons/test_validation.py similarity index 100% rename from tile_engine/ops/gemm/commons/test_validation.py rename to tile_engine/ops/commons/test_validation.py diff --git a/tile_engine/ops/gemm/commons/validation_utils.py b/tile_engine/ops/commons/validation_utils.py similarity index 95% rename from tile_engine/ops/gemm/commons/validation_utils.py rename to tile_engine/ops/commons/validation_utils.py index 3077ae4ba0..3eb7bf8b57 100644 --- a/tile_engine/ops/gemm/commons/validation_utils.py +++ b/tile_engine/ops/commons/validation_utils.py @@ -125,38 +125,13 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [32, 32, 64], ], }, - "gfx1201": { + "gfx1201": { # Check how to handle for GEMM and Multi D "fp16_fp16_fp16": [ [16, 16, 16], ], }, } -# Supported warp tile combinations for different GPU architectures and data types -WARP_SUPPORTED_COMBINATIONS = { - "gfx90a": [ - [1, 4, 1], - [2, 2, 1], - [4, 1, 1], - ], - "gfx942": [ - [1, 4, 1], - [2, 2, 1], - [4, 1, 1], - ], - "gfx950": [ - [1, 4, 1], - [2, 2, 1], - [4, 1, 1], - ], - "gfx1201": [ - [2, 4, 1], - [1, 8, 1], - [8, 1, 1], - [4, 2, 1], - ], -} - # Unsupported trait combinations TRAIT_UNSUPPORTED_COMBINATIONS = { ("compv3", "cshuffle", "interwave"), @@ -441,6 +416,20 @@ def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: return a_layout, b_layout, c_layout +def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]: + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcrr', 'ccrr', 'crrr', 'rrrr'. + """ + code = str(layout_code).strip().lower() + + a_layout = LAYOUT_MAP[code[0]] + b_layout = LAYOUT_MAP[code[1]] + c_layout = LAYOUT_MAP[code[2]] + d0_layout = LAYOUT_MAP[code[3]] + d1_layout = LAYOUT_MAP[code[3]] + return a_layout, b_layout, c_layout, [d0_layout, d1_layout] + + def validate_whole_wg_cover_configuration( tile_m, tile_n, @@ -464,13 +453,13 @@ def validate_whole_wg_cover_configuration( # A matrix validation if layout[0] == "r": - XPerTile = tile_k - YPerTile = tile_m - vector_load_size = get_global_vector_load_size( BlockSize, tile_k, a_datatype, tile_m, tile_k ) + XPerTile = tile_k + YPerTile = tile_m + elif layout[0] == "c": vector_load_size = get_global_vector_load_size( BlockSize, tile_k, a_datatype, tile_m, tile_m @@ -485,7 +474,6 @@ def validate_whole_wg_cover_configuration( ) if not wg_cover_core_valid: - print("I am here 1") logging.debug( f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" ) @@ -521,7 +509,7 @@ def validate_whole_wg_cover_configuration( if not wg_cover_core_valid: print("I am here 3") logging.debug( - f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" + f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}" ) return False, wg_cover_core_error @@ -540,7 +528,6 @@ def validate_whole_wg_cover_configuration( XPerTile, YPerTile, BlockSize, vector_load_size, warp_size ) if not wg_cover_core_valid: - print("I am here 4") logging.debug( f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}" ) @@ -557,7 +544,7 @@ def wg_cover_core_validation( warp_size: int, ) -> Tuple[bool, str]: if XPerTile % vector_load_size != 0: - return False + return False, "XPerTile is not divisible by vector_load_size" num_warps = BlockSize / warp_size LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size) @@ -567,7 +554,7 @@ def wg_cover_core_validation( Y1 = warp_size // X0 if X0 * Y1 != warp_size: - return False, "" + return False, "X0 * Y1 != warp_size" return True, "" @@ -583,9 +570,9 @@ def get_global_vector_load_size( PackedSize = 1 if ( - XPerTile % (PackedSize * 32 / element_size(DataType)) == 0 + PackedSize == 2 + and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0 - and PackedSize == 2 ): return PackedSize * 32 / element_size(DataType) elif ( diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 3c18fc4952..a72b6c40ab 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -122,15 +122,15 @@ function(build_individual_gemm_targets datatype layout) if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") set(config_filename "$ENV{GEMM_CONFIG_FILE}") set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") - message(STATUS " Using config from environment variable: ${config_filename}") + message(VERBOSE " Using config from environment variable: ${config_filename}") elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") # Use CMake variable if set set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") - message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}") + message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}") else() # Use default config for all layouts set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - message(STATUS " Using default config for layout ${layout}") + message(VERBOSE " Using default config for layout ${layout}") endif() # Check if config file exists @@ -151,16 +151,16 @@ function(build_individual_gemm_targets datatype layout) endif() # Generate individual kernel files using parallel version - message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") - message(STATUS " Working path: ${working_path}") - message(STATUS " Config file: ${json_blob}") - message(STATUS " Python executable: ${Python3_EXECUTABLE}") - message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") + message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(VERBOSE " Working path: ${working_path}") + message(VERBOSE " Config file: ${json_blob}") + message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") + message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") # Create working directory first file(MAKE_DIRECTORY ${working_path}) - message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${working_path} --datatype ${datatype} --layout ${layout} @@ -169,7 +169,7 @@ function(build_individual_gemm_targets datatype layout) --list_kernels ") # First, just list the kernels (fast operation) - message(STATUS " Listing kernel configurations...") + message(VERBOSE " Listing kernel configurations...") execute_process( COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${working_path} @@ -192,7 +192,7 @@ function(build_individual_gemm_targets datatype layout) if(EXISTS ${working_path}/gemm_kernel_count.txt) file(READ ${working_path}/gemm_kernel_count.txt kernel_count) string(STRIP "${kernel_count}" kernel_count) - message(STATUS " Found ${kernel_count} kernel configurations") + message(VERBOSE " Found ${kernel_count} kernel configurations") else() message(FATAL_ERROR "Kernel count file not found") endif() @@ -216,10 +216,10 @@ function(build_individual_gemm_targets datatype layout) endfunction() # Main build logic - Only individual builds supported -message(STATUS "=== Starting Tile Engine GEMM Configuration ===") -message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}") -message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}") -message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===") +message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}") +message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}") +message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") # Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 set(GEMM_GPU_TARGETS_INDIVIDUAL "") @@ -228,7 +228,7 @@ set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target}) - message(STATUS " Adding GPU target: ${target}") + message(VERBOSE " Adding GPU target: ${target}") endif() endforeach() @@ -236,7 +236,7 @@ endforeach() if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") else() - message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") + message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") # Enable parallel compilation optimizations # Set up job pools for better parallel compilation control @@ -251,12 +251,12 @@ else() find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) - message(STATUS "Using ccache for faster compilation") + message(VERBOSE "Using ccache for faster compilation") else() message(WARNING "ccache requested but not found") endif() else() - message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") + message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") endif() # Create master collection targets diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 81b25e592f..1aff42b902 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -8,12 +8,30 @@ import multiprocessing import concurrent.futures from pathlib import Path import logging -from commons.validation_utils import ( - is_tile_config_valid, - is_trait_combination_valid, - get_dtype_string, - get_abc_layouts, -) +import importlib.util + + +def _import_validation_utils(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + ) + validation_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(validation_utils) + + return validation_utils + + +# Import validation functions +_validation_utils = _import_validation_utils() +is_tile_config_valid = _validation_utils.is_tile_config_valid +is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_dtype_string = _validation_utils.get_dtype_string +get_abc_layouts = _validation_utils.get_abc_layouts logging.basicConfig(level=logging.INFO) @@ -563,6 +581,8 @@ struct SelectedKernel {{ tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() k_block_per_cu = self.config.get("k_block_per_cu") + if k_block_per_cu is None: + k_block_per_cu = 1 # Prepare work items for parallel processing work_items = [] @@ -574,11 +594,12 @@ struct SelectedKernel {{ trait_combo, k_block_per_cu, self.working_path, + self.gpu_target, self.datatype, self.layout, + self.config_json, ) ) - print( f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." ) @@ -615,7 +636,6 @@ struct SelectedKernel {{ print( f" Progress: {completed}/{len(work_items)} kernels generated" ) - try: result = future.result() if result: @@ -662,10 +682,19 @@ struct SelectedKernel {{ def _generate_single_kernel_individual(work_item): """Worker function to generate a single individual kernel file""" - tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item + ( + tile_config, + trait_combo, + k_block_per_cu, + working_path, + gpu_target, + datatype, + layout, + config_json, + ) = work_item # Create a temporary builder instance for this worker - builder = GemmKernelBuilder(working_path, datatype, layout) + builder = GemmKernelBuilder(working_path, gpu_target, datatype, layout, config_json) try: kernel_name, instance_code = builder._generate_kernel_instance( @@ -798,6 +827,8 @@ def main(): ) k_block_per_cu = builder.config.get("k_block_per_cu") + if k_block_per_cu is None: + k_block_per_cu = 1 # Generate the kernel kernel_name, instance_code = builder._generate_kernel_instance( diff --git a/tile_engine/ops/gemm_multi_d/CMakeLists.txt b/tile_engine/ops/gemm_multi_d/CMakeLists.txt index 01bbab53de..8d9c087e24 100644 --- a/tile_engine/ops/gemm_multi_d/CMakeLists.txt +++ b/tile_engine/ops/gemm_multi_d/CMakeLists.txt @@ -1,175 +1,311 @@ - set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)") -set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)") +set(GEMM_MULTI_D_LAYOUT "rcrr;rrrr;crrr;ccrr" CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)") +set(GEMM_MULTI_D_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function") -function(build_gemm_multi_d_for_datatype_layout datatype layout) - # Filter GPU targets to only gfx90a, gfx942, and gfx950 - set(GEMM_GPU_TARGETS "") - set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") - - foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_GPU_TARGETS ${target}) - endif() - endforeach() - - # Skip compilation if no matching targets found - if(NOT GEMM_GPU_TARGETS) - message(WARNING "Skipping Tile Engine GEMM Multi D compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +option(ENABLE_CCACHE_GEMM_MULTI_D "Enable ccache for GEMM Multi D ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_MULTI_D_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM Multi D targets +function(create_individual_gemm_multi_d_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM Multi D target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") return() endif() - - message(STATUS "Building GEMM Multi D for GPU targets: ${GEMM_GPU_TARGETS}") - + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - # Comment this if-else block when using user_provided_config - if(layout STREQUAL "rcrr") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_multi_d_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_multi_d_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + --gpu_target "${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}" + DEPENDS ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + EXCLUDE_FROM_ALL + ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_MULTI_D_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_MULTI_D_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_multi_d_all ${target_name}) + add_dependencies(benchmark_gemm_multi_d_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_multi_d_${layout} ${target_name}) + add_dependencies(benchmark_gemm_multi_d_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_multi_d_${pipeline}_pipeline ${target_name}) + add_dependencies(benchmark_gemm_multi_d_${epilogue}_epilogue ${target_name}) + add_dependencies(benchmark_gemm_multi_d_${scheduler}_scheduler ${target_name}) +endfunction() + +# Function to build individual GEMM Multi D targets +function(build_individual_gemm_multi_d_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_MULTI_D_CONFIG_FILE + # 2. CMake variable GEMM_MULTI_D_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_MULTI_D_CONFIG_FILE} AND NOT "$ENV{GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_MULTI_D_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(VERBOSE " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_MULTI_D_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_MULTI_D_CONFIG_FILE}") + message(VERBOSE " Using custom config: ${GEMM_MULTI_D_CONFIG_FILE}") else() - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(VERBOSE " Using default config for layout ${layout}") endif() - # uncomment this if you want to use user_provided_config.json - # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") - - # Generate kernel list - execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(VERBOSE " Working path: ${working_path}") + message(VERBOSE " Config file: ${json_blob}") + message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") + message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py --working_path ${working_path} --datatype ${datatype} --layout ${layout} --elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION} --config_json ${json_blob} - --list_blobs - --gpu_target ${GEMM_GPU_TARGETS} - RESULT_VARIABLE ret - ) - if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}") - endif() + --gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL} + --list_kernels ") - file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs) - file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range) - - # Generate the blobs - add_custom_command( - OUTPUT ${codegen_blobs} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py - --working_path "${working_path}" + # First, just list the kernels (fast operation) + message(VERBOSE " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py + --working_path ${working_path} --datatype ${datatype} --layout ${layout} --elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION} - --config_json "${json_blob}" - --gen_blobs - --gpu_target ${GEMM_GPU_TARGETS} - COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}" + --config_json ${json_blob} + --gpu_target ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL} + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error ) - add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs}) - set(intermediate_libs) - list(LENGTH codegen_blobs codegen_blobs_len) + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() - foreach(blob IN LISTS codegen_blobs_range) - string(STRIP "${blob}" stripped_blob) - separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}") - # Each line is: - list(GET spilit_blob 0 name) - list(GET spilit_blob 1 first) - list(GET spilit_blob 2 last) - math(EXPR total_files "${last} - ${first}") - if(total_files EQUAL 0) - continue() # nothing for this trait - endif() + # Read kernel count + if(EXISTS ${working_path}/gemm_multi_d_kernel_count.txt) + file(READ ${working_path}/gemm_multi_d_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(VERBOSE " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() - # Object libraries (chunked) per trait - set(sub_intermediate_libs) - set(chunk_size 3) - math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}") - math(EXPR num_chunks_minus_1 "${num_chunks} - 1") - - foreach(i RANGE 0 ${num_chunks_minus_1}) - math(EXPR start "${first} + ${i} * ${chunk_size} ") - math(EXPR end "${start} + ${chunk_size} - 1") - - set(chunk_files) - foreach(j RANGE ${start} ${end}) - if(j LESS ${last} AND j LESS ${codegen_blobs_len}) - list(GET codegen_blobs ${j} f) - list(APPEND chunk_files "${f}") - endif() - endforeach() - - #list(LENGTH chunk_files chunk_files_len) - #if(chunk_files_len AND chunk_files_len GREATER 1) - if(chunk_files) - set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}") - add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) - set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) - endif() + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_multi_d_kernel_list.txt) + file(STRINGS ${working_path}/gemm_multi_d_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + # Create individual target + create_individual_gemm_multi_d_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") endforeach() - - # ------------------ Bundle the object libs into one static lib --------- - #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) - #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) - if(sub_intermediate_libs) - set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}") - # Collect the $ expressions - - set(obj_exprs) - foreach(objlib IN LISTS sub_intermediate_libs) - list(APPEND obj_exprs $) - endforeach() - - add_library(${intermediate_lib_name} STATIC ${obj_exprs}) - add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout}) - set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - #foreach(objlib IN LISTS sub_intermediate_libs) - # target_sources(${intermediate_lib_name} PRIVATE $) - #endforeach() - list(APPEND intermediate_libs ${intermediate_lib_name}) - endif() - - endforeach() - - # Interface library for instances - add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE) - add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout}) - target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs}) - target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX) - - # Host API interface library - add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE) - target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout}) - target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - - - - # Executable per datatype - set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}") - add_executable(${exec_name} benchmark_gemm_multi_d.cpp) - set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout}) - target_compile_options(${exec_name} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress - ) + else() + message(FATAL_ERROR "Kernel list file not found") + endif() endfunction() -# Process each datatype in isolation -foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE) - foreach(l IN LISTS GEMM_MULTI_D_LAYOUT) - build_gemm_multi_d_for_datatype_layout(${dt} ${l}) - endforeach() +# Main build logic - Only individual builds supported +message(VERBOSE "=== Starting Tile Engine GEMM Multi D Configuration ===") +message(VERBOSE "GEMM_MULTI_D_DATATYPE: ${GEMM_MULTI_D_DATATYPE}") +message(VERBOSE "GEMM_MULTI_D_LAYOUT: ${GEMM_MULTI_D_LAYOUT}") +message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, gfx950 +set(GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL ${target}) + message(VERBOSE " Adding GPU target: ${target}") + endif() endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM Multi D build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(VERBOSE "Building individual GEMM Multi D targets for GPU targets: ${GEMM_MULTI_D_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM_MULTI_D) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(VERBOSE "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(VERBOSE "ccache disabled for GEMM Multi D ops (use -DENABLE_CCACHE_GEMM_MULTI_D=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_multi_d_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE) + add_custom_target(benchmark_gemm_multi_d_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_MULTI_D_LAYOUT) + add_custom_target(benchmark_gemm_multi_d_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE) + foreach(l IN LISTS GEMM_MULTI_D_LAYOUT) + add_custom_target(benchmark_gemm_multi_d_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM Multi D kernels + set(GEMM_MULTI_D_PIPELINES "mem;compv3;compv4") + set(GEMM_MULTI_D_EPILOGUES "default;cshuffle") + set(GEMM_MULTI_D_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_MULTI_D_PIPELINES) + add_custom_target(benchmark_gemm_multi_d_${pipeline}_pipeline) + endforeach() + + foreach(epilogue IN LISTS GEMM_MULTI_D_EPILOGUES) + add_custom_target(benchmark_gemm_multi_d_${epilogue}_epilogue) + endforeach() + + foreach(scheduler IN LISTS GEMM_MULTI_D_SCHEDULERS) + add_custom_target(benchmark_gemm_multi_d_${scheduler}_scheduler) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE) + foreach(l IN LISTS GEMM_MULTI_D_LAYOUT) + build_individual_gemm_multi_d_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm_multi_d/README.md b/tile_engine/ops/gemm_multi_d/README.md deleted file mode 100644 index 66f0ed80af..0000000000 --- a/tile_engine/ops/gemm_multi_d/README.md +++ /dev/null @@ -1,110 +0,0 @@ - -CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections - -# Kernel Configurations - -# User Specific -Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time. -For reference please see `./configs/user_provided_config.json`. - -# Default -The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` - -If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. - -## Build Instructions -``` bash -# in the root of composable kernel create build directory -mkdir build && cd build -# build composable kernel -# replace [Arch] with the appropriate architecture or leave blank and -# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16]) -# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) -# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default. -../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul" -# generate different executable for each passed datatype -make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j -make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j -``` -`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory. - -`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified. - -``` bash -rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild -``` - -## For eaxmple build for gfx942 for datatype with rcr layout -``` bash -mkdir build && cd build -../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr" -make benchmark_gemm_multi_d_fp16_rcrr -j - -## benchmark_gemm inputs -``` - -m The value for m dimension. Default is 3840. - -n The value for n dimension. Default is 4096. - -k The value for k dimension. Default is 2048. - -stride_a The stride value for tensor A. Default is 0. - -stride_b The stride value for tensor B. Default is 0. - -stride_ds The stride value for tensor Ds. Default is 0. - -stride_e The stride value for tensor E. Default is 0. - -split_k The split value for k dimension. Default is 1. - -verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported. - -log Wether output kernel instance information or not. Possible values are true or false. Default is false. - -warmup The number of iterations before benchmark the kernel. Default is 50. - -repeat The number of iterations to benchmark the kernel. Default is 100. - -timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true. - -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. - -flush_cache To flush cache, possible values are true or false. Default is false. - -rotating_count Number of iterations to rotate the cache. Default is 5. - -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. - -csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel. - -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. - -scheduler The type of scheduler. Possible values are intrawave. Default is intrawave. - -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. - -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. - -pad_n Whether pad or not in n direction. Possible values are true or false. Default is false. - -pad_k Whether pad or not in k direction. Possible values are true or false. Default is false. - -Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json -``` -Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. - -## Example - -The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. - -```json -{ - /// other parameters /// - - "tile_m": { - "values": [256] - }, - "tile_n": { - "values": [256] - }, - "tile_k": { - "values": [64, 32] - }, - - /// other parameters /// - - "pipeline": { - "values": ["compv3", "compv4", "mem"] - }, - "scheduler": { - "values": ["intrawave", "interwave"] - }, - "epilogue": { - "values": ["cshuffle"] - } -} -``` - -At runtime, a specific subset of the generated kernels can be selected using command-line arguments. -``` bash -./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle -``` -The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. diff --git a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp b/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp deleted file mode 100644 index 764a295809..0000000000 --- a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp +++ /dev/null @@ -1,73 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include - -#include "benchmark_gemm_multi_d.hpp" -#include "gemm_multi_d_profiler.hpp" - -void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser) -{ - GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"), - arg_parser.get_int("m"), - arg_parser.get_int("n"), - arg_parser.get_int("k"), - arg_parser.get_int("stride_a"), - arg_parser.get_int("stride_b"), - arg_parser.get_int("stride_ds"), - arg_parser.get_int("stride_ds"), - arg_parser.get_int("stride_e"), - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - ALayout::name, - BLayout::name, - D0Layout::name, - D1Layout::name, - ELayout::name}; - - Setting setting{arg_parser.get_int("warmup"), - arg_parser.get_int("repeat"), - arg_parser.get_bool("timer"), - arg_parser.get_int("verify"), - arg_parser.get_int("init"), - arg_parser.get_bool("log"), - arg_parser.get_str("csv_filename"), - arg_parser.get_bool("flush_cache"), - arg_parser.get_int("rotating_count")}; - - auto& profiler = GemmMultiDProfiler::instance(setting); - - try - { - auto kernel_func = get_kernel_func_by_trait(arg_parser); - profiler.benchmark(gemm_multi_d_problem, kernel_func); - profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); - } - catch(const std::exception& e) - { - std::cerr << "Benchmark failed: " << e.what() << std::endl; - } -} - -int main(int argc, char* argv[]) -{ - try - { - auto [result, parser] = create_args(argc, argv); - if(!result) - return EXIT_FAILURE; - benchmark_gemm_multi_d(parser); - return 0; - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << "\n"; - return EXIT_FAILURE; - } -} diff --git a/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json b/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json deleted file mode 100644 index cd638d9af0..0000000000 --- a/tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "tile_config": { - "tile_m": { - "values": [ - 256 ] - }, - "tile_n": { - "values": [ - 128 - ] - }, - "tile_k": { - "values": [ - 32 - ] - }, - "warp_m": { - "values": [ - 2 - ] - }, - "warp_n": { - "values": [ - 2 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] - }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3" - ] - }, - "scheduler": { - "values": [ - "intrawave" - ] - }, - "epilogue": { - "values": [ - "cshuffle" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - } - } -} \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/configs/default_config.json b/tile_engine/ops/gemm_multi_d/configs/default_config.json index 6d1afa4425..2447428158 100644 --- a/tile_engine/ops/gemm_multi_d/configs/default_config.json +++ b/tile_engine/ops/gemm_multi_d/configs/default_config.json @@ -1,84 +1,104 @@ { - "tile_config": { - "tile_m": { - "values": [ - 256 - ] + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_k": { + "max": 256, + "min": 64, + "step": 64 + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } }, - "tile_n": { - "values": [ - 128 - ] + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, + true + ] + } }, - "tile_k": { - "values": [ - 32 - ] - }, - "warp_m": { - "values": [ - 2 - ] - }, - "warp_n": { - "values": [ - 2 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] - }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4", - "mem" - ] - }, - "scheduler": { - "values": [ - "intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "cshuffle" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - } - } -} \ No newline at end of file + "k_block_per_cu": 1 +} diff --git a/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json b/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json index 243d858fe5..40a7dda6cc 100644 --- a/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json +++ b/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json @@ -2,12 +2,12 @@ "tile_config": { "tile_m": { "values": [ - 256 + 64 ] }, "tile_n": { "values": [ - 256 + 192 ] }, "tile_k": { @@ -42,24 +42,24 @@ }, "warp_tile_k": { "values": [ - 16 + 8 ] } }, "trait_config": { "pipeline": { "values": [ - "compv3" + "compv4" ] }, "scheduler": { "values": [ - "intrawave" + "intrawave" ] }, "epilogue": { "values": [ - "cshuffle" + "cshuffle" ] }, "pad_m": { @@ -76,6 +76,12 @@ "values": [ false ] + }, + "persistent": { + "values": [ + true + ] } - } + }, + "k_block_per_cu": 1 } \ No newline at end of file diff --git a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp similarity index 78% rename from tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp rename to tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp index f52d69e374..53dcdb5e1f 100644 --- a/tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -7,80 +7,14 @@ #include #include #include +#include -#include "gemm_multi_d_host_api.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_d_common.hpp" -struct GemmMultiDProblem -{ - int split_k_; - int m_, n_, k_; - int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_; - - std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_; - std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_; - - friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem) - { - os << "{\n" - << " \"split_k\":" << problem.split_k_ << ",\n" - << " \"m\":" << problem.m_ << ",\n" - << " \"n\":" << problem.n_ << ",\n" - << " \"k\":" << problem.k_ << ",\n" - << " \"stride_a\":" << problem.stride_a_ << ",\n" - << " \"stride_b\":" << problem.stride_b_ << ",\n" - << " \"stride_d0\":" << problem.stride_d0_ << ",\n" - << " \"stride_d1\":" << problem.stride_d1_ << ",\n" - << " \"stride_e\":" << problem.stride_e_ << ",\n" - << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" - << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" - << " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n" - << " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n" - << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" - << " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n" - << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" - << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" - << " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n" - << " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n" - << " \"layout_e\":\"" << problem.layout_e_ << "\"\n" - << "}"; - return os; - } -}; - -struct Setting -{ - int n_warmup_; - int n_repeat_; - bool is_gpu_timer_; - int verify_; - int init_method_; - bool log_; - std::string csv_filename_; - bool flush_cache_; - int rotating_count_; -}; - -// @brief Function to get the kernel output with reference implementation on CPU -void gemm_multi_d_host_reference(int verify, - ck_tile::HostTensor& a_m_k, - ck_tile::HostTensor& b_k_n, - ck_tile::HostTensor& d0_m_n, - ck_tile::HostTensor& d1_m_n, - ck_tile::HostTensor& e_m_n_host_result) -{ - if(verify > 0) - { - // Currently supporting on CPU verification for Gemm Multi D - // e_m_n_host_result.SetZero(); - ck_tile::reference_gemm_multiple_d( - a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result); - } -} +// Data types and Layouts are defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts enum class Metric { @@ -100,6 +34,43 @@ inline constexpr auto get_metric_name(Metric m) } } +struct GemmMultiDProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_c_; + + std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_c_; + std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_c_; + + friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_d0\":" << problem.stride_d0_ << ",\n" + << " \"stride_d1\":" << problem.stride_d1_ << ",\n" + << " \"stride_c\":" << problem.stride_c_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n" + << " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n" + << " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\"" << "\n" + << "}"; + return os; + } +}; + struct PerformanceResult { double latency_; @@ -143,15 +114,28 @@ struct KernelInstance friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" - << " \"name\": \"" << "{\n" - << obj.name_ << "\n}" << "\",\n" - << " \"problem\": \"" << obj.problem_ << "\",\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; return os; } }; +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; + bool json_output_; +}; + inline std::string get_rocm_version() { std::ifstream version_file("/opt/rocm/.info/version"); @@ -164,6 +148,11 @@ inline std::string get_rocm_version() return "Unknown"; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -175,17 +164,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K, std::conditional_t; // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( + const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( + const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); + ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( + const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold @@ -195,16 +184,19 @@ auto calculate_rtol_atol(const ck_tile::index_t K, /// @brief Function to compare the results of the device and host computations bool compare(std::string instanceName, ck_tile::index_t K, - ck_tile::HostTensor& e_m_n_dev_result, - ck_tile::HostTensor& e_m_n_host_result) + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) { const float max_accumulated_value = - *std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end()); + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + const auto rtol_atol = + calculate_rtol_atol( + K, kbatch, max_accumulated_value); - bool pass = ck_tile::check_err(e_m_n_dev_result, - e_m_n_host_result, + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); @@ -216,3 +208,25 @@ bool compare(std::string instanceName, return pass; } + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_multi_d_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& d0_m_n, + ck_tile::HostTensor& d1_m_n, + ck_tile::HostTensor& c_m_n_host_result) +{ + if(verify > 0) + { + // Currently supporting on CPU verification for Gemm Multi D + // e_m_n_host_result.SetZero(); + ck_tile::reference_gemm_multiple_d( + a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result); + } +} diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py new file mode 100755 index 0000000000..fb81b9c2c2 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py @@ -0,0 +1,683 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import sys +import json +import subprocess +import argparse +import csv +import time +from pathlib import Path +from typing import List, Dict, Tuple, Optional + + +class GemmMultiDBenchmark: + def __init__(self, build_dir: str, verbose: bool = False): + self.build_dir = Path(build_dir) + self.verbose = verbose + self.results = [] + + def discover_kernels(self) -> List[Path]: + """Find all benchmark_gemm_multi_d_* executables in the build directory""" + bin_dir = self.build_dir / "bin" + if not bin_dir.exists(): + print(f"Error: Binary directory {bin_dir} does not exist") + return [] + + kernels = list(bin_dir.glob("benchmark_gemm_multi_d_*")) + if self.verbose: + print(f"Found {len(kernels)} kernel executables") + for k in kernels: + print(f" - {k.name}") + return kernels + + def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]: + """Extract comprehensive kernel information from filename""" + name = kernel_path.stem + + # Initialize with basic info + info = { + "executable": str(kernel_path), + "name": name, + "data_type": "unknown", + "layout": "unknown", + "pipeline": "unknown", + "scheduler": "unknown", + "epilogue": "unknown", + } + + # Parse the kernel name pattern: + # benchmark_gemm_multi_d_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16 + parts = name.split("_") + + if len(parts) >= 5: + # Extract data type (3rd part after benchmark_gemm_) + info["data_type"] = parts[4] if len(parts) > 4 else "unknown" + + # Extract layout (4th part) + info["layout"] = parts[5] if len(parts) > 5 else "unknown" + + # Extract pipeline (5th part) + info["pipeline"] = parts[6] if len(parts) > 6 else "unknown" + + # Extract epilogue (6th part) + info["epilogue"] = parts[7] if len(parts) > 7 else "unknown" + + # Extract scheduler (7th part) + info["scheduler"] = parts[8] if len(parts) > 8 else "unknown" + + # Extract detailed configuration from the end of the name + config_info = self.parse_detailed_config(name) + info.update(config_info) + + # Generate config ID + info["config_id"] = self.generate_config_id(info) + + return info + + def parse_detailed_config(self, kernel_name: str) -> Dict: + """Parse detailed configuration from kernel name""" + config = { + "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0}, + "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0}, + "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0}, + "optimization_flags": { + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + } + + # Split by underscore and look for patterns + parts = kernel_name.split("_") + + # Look for boolean flags (sequence of True/False values) + bool_sequence = [] + for i, part in enumerate(parts): + if part in ["True", "False"]: + bool_sequence.append(part == "True") + # Continue collecting consecutive boolean values + j = i + 1 + while j < len(parts) and parts[j] in ["True", "False"]: + bool_sequence.append(parts[j] == "True") + j += 1 + break + + # Assign boolean flags if we found them + # Order: pad_m, pad_n, pad_k, persistent (4 flags total) + if len(bool_sequence) >= 4: + config["optimization_flags"]["pad_m"] = bool_sequence[0] + config["optimization_flags"]["pad_n"] = bool_sequence[1] + config["optimization_flags"]["pad_k"] = bool_sequence[2] + config["optimization_flags"]["persistent"] = bool_sequence[3] + + # Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16) + # The pattern is: tile_sizes_warp_config_warp_tile + dimension_groups = [] + for part in parts: + if "x" in part and len(part.split("x")) == 3: + try: + dims = [int(x) for x in part.split("x")] + if all(d > 0 for d in dims): + dimension_groups.append(dims) + except ValueError: + continue + + # Assign dimensions based on order and magnitude + if len(dimension_groups) >= 3: + # Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile + sorted_groups = sorted(dimension_groups, key=max, reverse=True) + + # Largest dimensions = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smallest dimensions = warp config + config["warp_config"]["warp_m"] = sorted_groups[2][0] + config["warp_config"]["warp_n"] = sorted_groups[2][1] + config["warp_config"]["warp_k"] = sorted_groups[2][2] + + # Middle dimensions = warp tile + config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0] + config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1] + config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 2: + # If only 2 groups, assign based on magnitude + sorted_groups = sorted(dimension_groups, key=max, reverse=True) + + # Larger = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smaller = warp config + config["warp_config"]["warp_m"] = sorted_groups[1][0] + config["warp_config"]["warp_n"] = sorted_groups[1][1] + config["warp_config"]["warp_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 1: + # Only one group - assume it's tile sizes + config["tile_sizes"]["tile_m"] = dimension_groups[0][0] + config["tile_sizes"]["tile_n"] = dimension_groups[0][1] + config["tile_sizes"]["tile_k"] = dimension_groups[0][2] + + return config + + def generate_config_id(self, info: Dict) -> str: + """Generate a compact config ID from kernel info""" + # Create a compact identifier + parts = [ + info.get("data_type", "unk"), + info.get("layout", "unk"), + info.get("pipeline", "unk"), + info.get("scheduler", "unk"), + ] + + # Add tile configuration if available + tile_sizes = info.get("tile_sizes", {}) + if tile_sizes.get("tile_m", 0) > 0: + tile_str = ( + f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}" + ) + parts.append(tile_str) + + # Add warp config if available + warp_config = info.get("warp_config", {}) + if warp_config.get("warp_m", 0) > 0: + warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}" + parts.append(warp_str) + + # Add warp tile if available + warp_tile = info.get("warp_tile", {}) + if warp_tile.get("warp_tile_m", 0) > 0: + warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}" + parts.append(warp_tile_str) + + return "_".join(parts) + + def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]: + """Run a single kernel with given parameters and save output to individual JSON file""" + # Create results directory + results_dir = self.build_dir / "results" + results_dir.mkdir(exist_ok=True) + + # Generate unique JSON filename for this kernel + json_file = results_dir / f"{kernel_path.stem}.json" + + cmd = [str(kernel_path)] + + # Add parameters + for key, value in params.items(): + cmd.append(f"-{key}={value}") + + # Add JSON output flag for clean JSON output + cmd.append("-json_output=true") + + if self.verbose: + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if result.returncode != 0: + print(f"Error running {kernel_path.name}: {result.stderr}") + return None + + # Save raw output to individual JSON file + output = result.stdout.strip() + if output: + with open(json_file, "w") as f: + f.write(output) + + # Parse the JSON file + return self.parse_json_file(json_file) + else: + print(f"No output from {kernel_path.name}") + return None + + except subprocess.TimeoutExpired: + print(f"Timeout running {kernel_path.name}") + return None + except Exception as e: + print(f"Error running {kernel_path.name}: {e}") + return None + + def parse_json_file(self, json_file: Path) -> Optional[Dict]: + """Parse JSON data from individual kernel output file""" + try: + with open(json_file, "r") as f: + content = f.read().strip() + + # Parse the JSON directly since executables produce clean JSON + data = json.loads(content) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON from {json_file}: {e}") + return None + except Exception as e: + if self.verbose: + print(f"Error reading JSON file {json_file}: {e}") + return None + + def benchmark_problem_size( + self, + kernels: List[Path], + m: int, + n: int, + k: int, + split_k: int = 1, + verify: int = 0, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> List[Dict]: + """Benchmark all kernels for a specific problem size""" + results = [] + + params = { + "m": m, + "n": n, + "k": k, + "split_k": split_k, + "verify": verify, + "warmup": warmup, + "repeat": repeat, + "flush_cache": str(flush_cache).lower(), + "rotating_count": rotating_count, + } + + print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}") + + for kernel_path in kernels: + kernel_info = self.extract_kernel_info(kernel_path) + result = self.run_kernel(kernel_path, params) + + if result: + # Create new structured result format + structured_result = { + "name": kernel_info["name"], # Add name field for compatibility + "config_id": kernel_info["config_id"], + "problem": result.get("problem", {}), + "perf_result": result.get("perf_result", {}), + "config": { + "data_type": kernel_info["data_type"], + "layout": kernel_info["layout"], + "pipeline": kernel_info["pipeline"], + "scheduler": kernel_info["scheduler"], + "epilogue": kernel_info["epilogue"], + "tile_sizes": kernel_info.get("tile_sizes", {}), + "warp_config": kernel_info.get("warp_config", {}), + "warp_tile": kernel_info.get("warp_tile", {}), + "optimization_flags": kernel_info.get("optimization_flags", {}), + }, + "executable": kernel_info["executable"], + # Keep backward compatibility fields + "time_ms": result.get("time_ms", 0), + "tflops": result.get("tflops", 0), + "bandwidth_gb_s": result.get("bandwidth_gb_s", 0), + } + + results.append(structured_result) + + if self.verbose: + print( + f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms" + ) + + return results + + def find_best_kernel( + self, results: List[Dict], metric: str = "tflops" + ) -> Optional[Dict]: + """Find the best performing kernel based on metric""" + if not results: + return None + + if metric == "tflops": + return max(results, key=lambda x: x.get("tflops", 0)) + elif metric == "time_ms": + return min(results, key=lambda x: x.get("time_ms", float("inf"))) + elif metric == "bandwidth_gb_s": + return max(results, key=lambda x: x.get("bandwidth_gb_s", 0)) + else: + raise ValueError(f"Unknown metric: {metric}") + + def benchmark_sweep( + self, + problem_sizes: List[Tuple[int, int, int]], + split_k_values: List[int] = [1], + verify: bool = False, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> Dict: + """Run comprehensive benchmark sweep""" + kernels = self.discover_kernels() + if not kernels: + print("No kernels found!") + return {} + + all_results = [] + best_kernels = {} + + for m, n, k in problem_sizes: + for split_k in split_k_values: + results = self.benchmark_problem_size( + kernels, + m, + n, + k, + split_k, + verify=2 if verify else 0, + warmup=warmup, + repeat=repeat, + flush_cache=flush_cache, + rotating_count=rotating_count, + ) + + all_results.extend(results) + + # Find best kernel for this configuration + best = self.find_best_kernel(results) + if best: + key = f"m{m}_n{n}_k{k}_splitk{split_k}" + best_kernels[key] = best + print( + f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)" + ) + + self.results = all_results + return best_kernels + + def export_csv(self, filename: str): + """Export all results to CSV""" + if not self.results: + print("No results to export") + return + + # Get all unique keys from results + all_keys = set() + for result in self.results: + all_keys.update(result.keys()) + + # Sort keys for consistent output + fieldnames = sorted(all_keys) + + with open(filename, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + print(f"Results exported to {filename}") + + def export_best_kernels(self, best_kernels: Dict, filename: str): + """Export best kernel selections to file""" + with open(filename, "w") as f: + f.write("# Best kernel selections\n") + f.write( + "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n" + ) + + for key, kernel in sorted(best_kernels.items()): + f.write( + f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n" + ) + + print(f"Best kernels exported to {filename}") + + def export_json(self, filename: str, best_kernels: Dict = None): + """Export all results and best kernels to JSON with comprehensive metadata""" + from datetime import datetime + + # Calculate comprehensive summary statistics for all metrics + successful_results = [r for r in self.results if r.get("tflops", 0) > 0] + + tflops_values = [r.get("tflops", 0) for r in successful_results] + bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results] + latency_values = [ + r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0 + ] + + # Performance breakdown by kernel type + pipeline_stats = {} + scheduler_stats = {} + data_type_stats = {} + + for result in successful_results: + # Get config info from the new structure + config = result.get("config", {}) + + # Pipeline statistics + pipeline = config.get("pipeline", "unknown") + if pipeline not in pipeline_stats: + pipeline_stats[pipeline] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + pipeline_stats[pipeline]["count"] += 1 + pipeline_stats[pipeline]["best_tflops"] = max( + pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0) + ) + + # Scheduler statistics + scheduler = config.get("scheduler", "unknown") + if scheduler not in scheduler_stats: + scheduler_stats[scheduler] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + scheduler_stats[scheduler]["count"] += 1 + scheduler_stats[scheduler]["best_tflops"] = max( + scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0) + ) + + # Data type statistics + data_type = config.get("data_type", "unknown") + if data_type not in data_type_stats: + data_type_stats[data_type] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + data_type_stats[data_type]["count"] += 1 + data_type_stats[data_type]["best_tflops"] = max( + data_type_stats[data_type]["best_tflops"], result.get("tflops", 0) + ) + + # Calculate averages for breakdown stats + for stats_dict, field_name in [ + (pipeline_stats, "pipeline"), + (scheduler_stats, "scheduler"), + (data_type_stats, "data_type"), + ]: + for key in stats_dict: + relevant_results = [ + r + for r in successful_results + if r.get("config", {}).get(field_name, "unknown") == key + ] + if relevant_results: + stats_dict[key]["avg_tflops"] = sum( + r.get("tflops", 0) for r in relevant_results + ) / len(relevant_results) + + output_data = { + "benchmark_metadata": { + "timestamp": datetime.now().isoformat(), + "total_kernels_tested": len(self.results), + "unique_kernels": len( + set(r.get("name", "unknown") for r in self.results) + ), + "successful_runs": len(successful_results), + "failed_runs": len(self.results) - len(successful_results), + }, + "performance_summary": { + "tflops_stats": { + "best": max(tflops_values, default=0), + "average": sum(tflops_values) / len(tflops_values) + if tflops_values + else 0, + "min": min(tflops_values, default=0), + "median": sorted(tflops_values)[len(tflops_values) // 2] + if tflops_values + else 0, + }, + "bandwidth_stats": { + "best_gb_s": max(bandwidth_values, default=0), + "average_gb_s": sum(bandwidth_values) / len(bandwidth_values) + if bandwidth_values + else 0, + "min_gb_s": min(bandwidth_values, default=0), + "median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2] + if bandwidth_values + else 0, + }, + "latency_stats": { + "best_ms": min(latency_values, default=0), + "average_ms": sum(latency_values) / len(latency_values) + if latency_values + else 0, + "max_ms": max(latency_values, default=0), + "median_ms": sorted(latency_values)[len(latency_values) // 2] + if latency_values + else 0, + }, + "kernel_type_breakdown": { + "by_pipeline": pipeline_stats, + "by_scheduler": scheduler_stats, + "by_data_type": data_type_stats, + }, + "total_problem_configurations": len(best_kernels) + if best_kernels + else 0, + }, + "kernel_results": self.results, + "best_kernels_by_problem": best_kernels or {}, + } + + with open(filename, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"JSON results exported to {filename}") + print(f" - Total kernels: {len(self.results)}") + print(f" - Successful runs: {len(successful_results)}") + print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}") + print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s") + print(f" - Best latency: {min(latency_values, default=0):.2f}ms") + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Multi D Kernel Benchmarking Tool" + ) + parser.add_argument( + "build_dir", help="Build directory containing kernel executables" + ) + parser.add_argument( + "--problem-sizes", + nargs="+", + default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], + help="Problem sizes as M,N,K tuples", + ) + parser.add_argument( + "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test" + ) + parser.add_argument("--verify", action="store_true", help="Enable verification") + parser.add_argument( + "--csv", + default="gemm_multi_d_benchmark_results.csv", + help="CSV output filename", + ) + parser.add_argument( + "--best", default="best_kernels.txt", help="Best kernels output filename" + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--warmup", + type=int, + default=50, + help="Number of warmup iterations (default: 50)", + ) + parser.add_argument( + "--repeat", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + default=True, + help="Enable cache flushing (default: True)", + ) + parser.add_argument( + "--rotating-count", + type=int, + default=1000, + help="Number of iterations to rotate cache (default: 1000)", + ) + parser.add_argument("--json", help="JSON output filename (optional)") + + args = parser.parse_args() + + # Parse problem sizes + problem_sizes = [] + for size_str in args.problem_sizes: + try: + m, n, k = map(int, size_str.split(",")) + problem_sizes.append((m, n, k)) + except ValueError: + print(f"Invalid problem size: {size_str}") + return 1 + + # Create benchmark instance + benchmark = GemmMultiDBenchmark(args.build_dir, verbose=args.verbose) + + # Run benchmark sweep + print("Starting GEMM Multi D kernel benchmark sweep...") + start_time = time.time() + + best_kernels = benchmark.benchmark_sweep( + problem_sizes=problem_sizes, + split_k_values=args.split_k, + verify=args.verify, + warmup=args.warmup, + repeat=args.repeat, + flush_cache=args.flush_cache, + rotating_count=args.rotating_count, + ) + + elapsed_time = time.time() - start_time + print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") + + # Export results + benchmark.export_csv(args.csv) + benchmark.export_best_kernels(best_kernels, args.best) + + # Export JSON if requested + if args.json: + benchmark.export_json(args.json, best_kernels) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp new file mode 100644 index 0000000000..032a625354 --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_d_profiler.hpp" +#include "gemm_multi_d_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_multi_d_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_ds", "0", "The stride value for tensor Ds . Default is 0.") + .insert("stride_c", "0", "The stride value for tensor C. Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "1", + "for validation on GPU. Default is 1, validation on CPU, as validation on GPU is " + "not supported.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert( + "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") + .insert( + "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "true", + "To flush cache, possible values are true or false. " + "Default is false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false. Default is " + "false") + .insert("json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false. " + "Default is " + "false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void benchmark_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + std::string dtype_d0 = DataTypeTraits::name; + std::string dtype_d1 = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + std::string layout_d0 = D0Layout::name; + std::string layout_d1 = D1Layout::name; + + // Create GemmMultiDProblem struct + GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_ds"), + arg_parser.get_int("stride_ds"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_d0, + dtype_d1, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_d0, + layout_d1, + layout_c}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + // Get the profiler instance + auto& profiler = GemmMultiDProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::GemmMultiDHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_multi_d_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + + benchmark_single(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py deleted file mode 100644 index 32ed616d75..0000000000 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -Mappings and utility functions for kernel code generation. -""" - -DATA_TYPE_MAP = { - "fp32": "float", - "fp16": "ck_tile::half_t", - "bf16": "ck_tile::bf16_t", - "int8": "ck_tile::int8_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int4": "ck_tile::pk_int4_t", - "int32": "ck_tile::int32_t", -} - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - -# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW -# DEFAULT_EPILOGUE = """ -# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< -# ck_tile::DefaultGemm2DEpilogueProblem>; -# """ - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; -""" - -PIPELINE_MAP = { - "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"], - "compv3": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "ck_tile::GemmPipelineAgBgCrCompV3", - ], - "compv4": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV4", - "ck_tile::GemmPipelineAgBgCrCompV4", - ], -} - -SCHEDULER_MAP = { - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", -} - -# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} - -EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE} - - -def BOOL_MAP(b_): - return {True: "true", False: "false"}[bool(b_)] - - -# Can add some more supported combinations -warp_tile_supported_combinations = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - }, -} - -# Remove some unsupported combinations -trait_unsupported_combinations = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp new file mode 100644 index 0000000000..4732f2a1ba --- /dev/null +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +//[TODO] This can be moved to commons +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py deleted file mode 100644 index e5a879158f..0000000000 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py +++ /dev/null @@ -1,250 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -Handles loading, parsing, and validation of JSON and Argument configuration parameters. -""" - -from pathlib import Path -from dataclasses import dataclass -from typing import List, Optional, Union, Type -import json - - -@dataclass -class EnumConfigParam: - """Represents an enumeration-type configuration parameter""" - - values: List[Union[int, str, bool]] - - -@dataclass -class RangeConfigParam: - """Represents a numeric range-type configuration parameter""" - - min: int - max: int - step: int - exclude: Optional[List[int]] - - def generate_candidates(self) -> List[int]: - """Generates valid candidates after applying range constraints""" - - if self.min > self.max: - raise ValueError(f"Invalid range: min({self.min}) > max({self.max})") - if self.step <= 0: - raise ValueError(f"Step must be positive, got {self.step}") - - candidates = list(range(self.min, self.max + 1, self.step)) - - if hasattr(self, "exclude") and self.exclude: - if not isinstance(self.exclude, list): - raise TypeError("exclude must be list type") - exclude_set = set(self.exclude) - candidates = [x for x in candidates if x not in exclude_set] - - if not candidates: - raise ValueError( - f"No valid candidates for range [{self.min}-{self.max}] " - f"with step {self.step} and excludes {self.exclude}" - ) - - return candidates - - -@dataclass -class DataType: - """Configuration class for data type parameter.""" - - a_datatype: str - b_datatype: str - e_datatype: str - d0_datatype: str - d1_datatype: str - ds_datatype: List[str] - - -@dataclass -class Layout: - """Configuration class for Layout parameter.""" - - a_layout: str - b_layout: str - e_layout: str - d0_layout: str - d1_layout: str - ds_layout: List[str] - - -@dataclass -class ArgumentConfig: - """Configuration class for Argument parameter.""" - - datatypes: DataType - layouts: Layout - function_name: str - - @classmethod - def from_args( - cls: Type["ArgumentConfig"], - datatype: str, - layout: str, - elementwise_function: str, - ) -> "ArgumentConfig": - """configuration loader with validation controls""" - - datatypes = DataType( - a_datatype=datatype, - b_datatype=datatype, - e_datatype=datatype, - d0_datatype=datatype, - d1_datatype=datatype, - ds_datatype=[datatype, datatype], - ) - - layout_parts = layout.lower() - assert len(layout_parts) == 4, ( - f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)" - ) - assert layout_parts[0] in ("r", "c"), ( - f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)" - ) - assert layout_parts[1] in ("r", "c"), ( - f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)" - ) - assert layout_parts[2] == "r", ( - f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" - ) - assert layout_parts[3] == "r", ( - f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)" - ) - - layouts = Layout( - a_layout=layout[0], - b_layout=layout[1], - e_layout=layout[2], - d0_layout=layout[3], - d1_layout=layout[3], - ds_layout=[layout[3], layout[3]], - ) - # Elementwise function name validation - valid_functions = ["mul", "add", "passthrough"] - if elementwise_function not in valid_functions: - raise ValueError( - f"Invalid elementwise function: {elementwise_function}. " - f"Valid options are: {', '.join(valid_functions)}" - ) - - # Set the function name based on the elementwise function - if elementwise_function == "mul": - function_name = "MultiDMultiply" - elif elementwise_function == "add": - function_name = "MultiDAdd" - elif elementwise_function == "passthrough": - function_name = "PassThrough" # TODO Change this - - return cls(datatypes=datatypes, layouts=layouts, function_name=function_name) - - -@dataclass -class TileConfig: - """Configuration class for tile parameter.""" - - tile_m: Union[EnumConfigParam, RangeConfigParam] - tile_n: Union[EnumConfigParam, RangeConfigParam] - tile_k: Union[EnumConfigParam, RangeConfigParam] - - warp_m: Union[EnumConfigParam, RangeConfigParam] - warp_n: Union[EnumConfigParam, RangeConfigParam] - warp_k: Union[EnumConfigParam, RangeConfigParam] - - warp_tile_m: Union[EnumConfigParam, RangeConfigParam] - warp_tile_n: Union[EnumConfigParam, RangeConfigParam] - warp_tile_k: Union[EnumConfigParam, RangeConfigParam] - - -@dataclass -class TraitConfig: - """Configuration class for kernel traits.""" - - pipeline: EnumConfigParam - scheduler: EnumConfigParam - epilogue: EnumConfigParam - pad_m: EnumConfigParam - pad_n: EnumConfigParam - pad_k: EnumConfigParam - - -@dataclass -class JsonConfig: - """Configuration class for JSON parameter.""" - - tile_config: TileConfig - trait_config: TraitConfig - - @classmethod - def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig": - """JSON configuration loader with validation controls""" - config_path = Path(filepath) - - try: - if not config_path.exists(): - raise FileNotFoundError(f"Config file {filepath} not found") - - with config_path.open("r") as f: - config_dict = json.load(f) - - # Parse tile config - def create_param(param_dict): - if "values" in param_dict: - return EnumConfigParam(values=param_dict["values"]) - else: - return RangeConfigParam( - min=param_dict["min"], - max=param_dict["max"], - step=param_dict["step"], - exclude=param_dict.get("exclude", []), - ) - - tile_config = TileConfig( - tile_m=create_param(config_dict["tile_config"]["tile_m"]), - tile_n=create_param(config_dict["tile_config"]["tile_n"]), - tile_k=create_param(config_dict["tile_config"]["tile_k"]), - warp_m=create_param(config_dict["tile_config"]["warp_m"]), - warp_n=create_param(config_dict["tile_config"]["warp_n"]), - warp_k=create_param(config_dict["tile_config"]["warp_k"]), - warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]), - warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]), - warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]), - ) - - # Parse trait config - trait_config = TraitConfig( - pipeline=EnumConfigParam( - values=config_dict["trait_config"]["pipeline"]["values"] - ), - scheduler=EnumConfigParam( - values=config_dict["trait_config"]["scheduler"]["values"] - ), - epilogue=EnumConfigParam( - values=config_dict["trait_config"]["epilogue"]["values"] - ), - pad_m=EnumConfigParam( - values=config_dict["trait_config"]["pad_m"]["values"] - ), - pad_n=EnumConfigParam( - values=config_dict["trait_config"]["pad_n"]["values"] - ), - pad_k=EnumConfigParam( - values=config_dict["trait_config"]["pad_k"]["values"] - ), - ) - - return cls(tile_config=tile_config, trait_config=trait_config) - - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON format: {str(e)}") - except KeyError as e: - raise KeyError(f"Missing required configuration field: {str(e)}") diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp deleted file mode 100644 index 41fddf30aa..0000000000 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck_tile/host.hpp" -#include "gemm_multi_d_dispatcher.hpp" -#include "gemm_multi_d_common.hpp" - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -inline auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") - .insert("n", "4096", "The value for n dimension. Default is 4096.") - .insert("k", "2048", "The value for k dimension. Default is 2048.") - .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") - .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") - .insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.") - .insert("stride_e", "0", "The stride value for tensor E Default is 0.") - .insert("split_k", "1", "The split value for k dimension. Default is 1.") - .insert("verify", - "1", - "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " - "for validation on GPU. Default is 1, validation on CPU, as validation on GPU is " - "not supported.") - .insert("log", - "false", - "Wether output kernel instance information or not. Possible values are true or " - "false. Default is false") - .insert("warmup", - "50", - "The number of iterations before benchmarking the kernel. Default is 50.") - .insert("repeat", - "100", - "The number of iterations for benchmarking the kernel. Default is 100.") - .insert("timer", - "true", - "Indicates whether the timer is a GPU timer. Possible values are true or false. " - "Default is true.") - .insert("init", - "0", - "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " - "for constant(1). Default is 0, random.") - .insert("flush_cache", - "false", - "To flush cache, possible values are true or false. " - "Default is false.") - .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.") - .insert("metric", - "0", - "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " - "tflops, or 2 for bandwidth. Default is 0, latency.") - .insert("csv_filename", - "gemm_multi_d_kernel", - "The filename of benchmark result. Default is set to gemm_multi_d_kernel.") - .insert( - "pipeline", - "compv3", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") - .insert("scheduler", - "intrawave", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " - "compv3.") - .insert( - "epilogue", - "cshuffle", - "The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.") - .insert("pad_m", - "false", - "Whether pad or not in m direction. Possible values are true or false. Default is " - "false.") - .insert("pad_n", - "false", - "Whether pad or not in n direction. Possible values are true or false. Default is " - "false.") - .insert("pad_k", - "false", - "Whether pad or not in k direction. Possible values are true or false. Default is " - "false."); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) -{ - KernelTraits trait; - trait.pipeline = arg_parser.get_str("pipeline"); - trait.scheduler = arg_parser.get_str("scheduler"); - trait.epilogue = arg_parser.get_str("epilogue"); - trait.pad_m = arg_parser.get_bool("pad_m"); - trait.pad_n = arg_parser.get_bool("pad_n"); - trait.pad_k = arg_parser.get_bool("pad_k"); - - return GemmMultiDDispatcher::dispatch(trait); -} diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py old mode 100755 new mode 100644 index cc534565d9..3f7858f146 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -1,471 +1,558 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -generate kernel instances to speed up compilation -""" +#!/usr/bin/env python +import os +import json import argparse import itertools +import multiprocessing +import concurrent.futures from pathlib import Path -from typing import List, Optional -from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam -from gemm_multi_d_codegen_utils import ( - DATA_TYPE_MAP, - LAYOUT_MAP, - PIPELINE_MAP, - SCHEDULER_MAP, - EPILOGUE_MAP, - BOOL_MAP, - warp_tile_supported_combinations, - trait_unsupported_combinations, - element_size, -) import logging +import importlib.util + + +def _import_validation_utils(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + ) + validation_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(validation_utils) + + return validation_utils + + +# Import validation functions +_validation_utils = _import_validation_utils() +is_tile_config_valid = _validation_utils.is_tile_config_valid +is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_dtype_string = _validation_utils.get_dtype_string +get_abcd_layouts = _validation_utils.get_abcd_layouts logging.basicConfig(level=logging.INFO) -class GemmMultiDCodeGenerator: - """GEMM (General Matrix Multiplication) Multi D code generator.""" - +class GemmMultiDKernelBuilder: def __init__( self, - args: argparse.Namespace, - user_provided_config: Optional[JsonConfig] = None, + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json=None, ): - self.output_dir = Path(args.working_path) - self.output_dir.mkdir(parents=True, exist_ok=True) + self.working_path = Path(working_path) + self.gpu_target = gpu_target + self.datatype = datatype + self.layout = layout + self.elementwise_function = elementwise_function + self.config_json = config_json - self.gpu_target = args.gpu_target + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) - if user_provided_config is not None: - self.config = user_provided_config - else: - config_path = ( - Path(__file__).resolve().parent / "configs" / "default_config.json" - ) - self.config = JsonConfig.from_json(config_path) + # Load configuration + if config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) - self.args = ArgumentConfig.from_args( - args.datatype, args.layout, args.elementwise_function - ) + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() - self.valid_trait_names: List[str] = [] - self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {} + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo - def list_all_trait_names(self): - """List all possible kernel trait names into file.""" - w_p = Path(self.output_dir) - file_path = w_p / "gemm_multi_d_instance_blobs.txt" - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - file_range_map = {} - # Write all file paths to the header file - files_listed = 0 - with file_path.open("w") as f: - # Core files - core_files = [ - "gemm_multi_d_common.hpp", - "gemm_multi_d_instances.hpp", - "gemm_multi_d_dispatcher.hpp", - ] - for core_file in core_files: - f.write(str(w_p / core_file) + "\n") - files_listed += 1 + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - # Trait header files - for trait in self.valid_trait_names: - trait_file = f"gemm_multi_d_{trait}.hpp" - f.write(str(w_p / trait_file) + "\n") - files_listed += 1 - file_name = set() - # Instance source files - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - start_idx = files_listed - for tile in tile_valid_params: - for ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - _, - _, - _, - ) in tile: - instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - if instance_name not in file_name: - file_name.add(instance_name) - f.write(str(w_p / instance_name) + "\n") - files_listed += 1 + kernel_name += f"_{tile_str}" - file_range_map[trait] = (start_idx, files_listed) - - file_path = w_p / "gemm_multi_d_instance_blobs_range.txt" - with file_path.open("w") as f: - for name, ranges in file_range_map.items(): - start, last = ranges - f.write(name + " " + f"{start}" + " " + f"{last}" + "\n") - - def _generate_all_traits(self): - """Generate all possible kernel traits names.""" - params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"] - - # Generate all unique_combinations - _unique = set( - itertools.product( - *[getattr(self.config.trait_config, param).values for param in params] - ) - ) - - for combo in _unique: - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo - current_combination = (pipeline, epilogue, scheduler) - - if current_combination not in trait_unsupported_combinations: - trait_name = ( - f"{pipeline}_{epilogue}_{scheduler}_" - f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}" + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } ) - self.valid_trait_names.append(trait_name) + + # Write kernel count + with open(self.working_path / "gemm_multi_d_kernel_count.txt", "w") as f: + f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_multi_d_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + tile_config = self.config["tile_config"] + + # Generate values in the config if default range is given + if tile_config.get("tile_m").get("values") is None: + tile_config.get("tile_m")["values"] = self._generate_values( + tile_config.get("tile_m").get("min"), + tile_config.get("tile_m").get("max"), + tile_config.get("tile_m").get("step"), + ) + if tile_config.get("tile_n").get("values") is None: + tile_config.get("tile_n")["values"] = self._generate_values( + tile_config.get("tile_n").get("min"), + tile_config.get("tile_n").get("max"), + tile_config.get("tile_n").get("step"), + ) + if tile_config.get("tile_k").get("values") is None: + tile_config.get("tile_k")["values"] = self._generate_values( + tile_config.get("tile_k").get("min"), + tile_config.get("tile_k").get("max"), + tile_config.get("tile_k").get("step"), + ) + + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m").get("values") + tile_n_values = tile_config.get("tile_n").get("values") + tile_k_values = tile_config.get("tile_k").get("values") + warp_m_values = tile_config.get("warp_m").get("values") + warp_n_values = tile_config.get("warp_n").get("values") + warp_k_values = tile_config.get("warp_k").get("values") + warp_tile_m_values = tile_config.get("warp_tile_m").get("values") + warp_tile_n_values = tile_config.get("warp_tile_n").get("values") + warp_tile_k_values = tile_config.get("warp_tile_k").get("values") + + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + + def _generate_values(self, min_val, max_val, step): + """Generate a list of values from min to max with the given step""" + values = [] + val = min_val + while val <= max_val: + values.append(val) + val += step + return values + + def _validate_tile_config( + self, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="compv4", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + return True + else: + # Full validation for generation + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype + + layout = self.layout + + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" + + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + self.gpu_target, + ) + + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + + trait_config = self.config["trait_config"] + + pipelines = trait_config.get("pipeline").get("values") + epilogues = trait_config.get("epilogue").get("values") + schedulers = trait_config.get("scheduler").get("values") + pad_m_values = trait_config.get("pad_m").get("values") + pad_n_values = trait_config.get("pad_n").get("values") + pad_k_values = trait_config.get("pad_k").get("values") + persistent_values = trait_config.get("persistent").get("values") + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) else: - logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}") + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) + return combinations - def _get_valid_trait_tile_combinations(self): - def get_tile_value(tile_param): - return ( - tile_param.generate_candidates() - if isinstance(tile_param, RangeConfigParam) - else tile_param.values - ) + def _generate_kernel_instance( + self, tile_config, trait_combo, k_block_per_cu, is_header=True + ): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo - tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.tile_m), - get_tile_value(self.config.tile_config.tile_n), - get_tile_value(self.config.tile_config.tile_k), - ) + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" ) - - warp_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_m), - get_tile_value(self.config.tile_config.warp_n), - get_tile_value(self.config.tile_config.warp_k), - ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - warp_tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_tile_m), - get_tile_value(self.config.tile_config.warp_tile_n), - get_tile_value(self.config.tile_config.warp_tile_k), - ) - ) + kernel_name += f"_{tile_str}" - tile_params = { - t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", } - for trait in self.valid_trait_names: - tile_valid_params = [ - tile for tile in tile_params if self.is_tile_valid(tile, trait) - ] + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } - if trait not in self.valid_trait_tile_combinations: - self.valid_trait_tile_combinations[trait] = [] - self.valid_trait_tile_combinations[trait].append(tile_valid_params) + # Map scheduler names to the correct enum values + scheduler_type_map = { + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "default": "ck_tile::GemmPipelineScheduler::Default", + } - def is_tile_valid(self, tile: tuple, trait: str) -> bool: - """Check if the tile configuration is valid for the given trait.""" - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile - pipeline, *_ = trait.split("_") + # Determine accumulator type based on datatype + acc_type = "float" - # Parameter validity check - invalid_params = [] - if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]: - invalid_params.append( - f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})" - ) - if (warp_m * warp_tile_m) == 0: - invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})") - if (warp_n * warp_tile_n) == 0: - invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})") - if (warp_k * warp_tile_k) == 0: - invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})") + # Determine output type + c_type = self.datatype + if self.datatype in ["fp8", "bf8"]: + c_type = "fp16" - if invalid_params: - logging.debug( - f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. " - f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), " - f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})" - ) - return False - # Dimension alignment check - alignment_issues = [] - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - if alignment_issues: - logging.debug( - f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # LDS capacity verification - matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype) - - matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype) - - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 - - if total_tile_in_lds > max_tile_size: - logging.debug( - f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" - f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False - - # Warp combination validation - warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}" - - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - gpu_name = self.gpu_target - - gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) - if not gpu_warp_tile_key: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) - if not allowed_combinations: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - if current_combination not in allowed_combinations: - logging.debug( - f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. " - f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}" - ) - return False - - return True - - def generate_all_instance_files(self): - """Generate all kernel instances files.""" - self._generate_common_header_file() - self._generate_all_trait_files() - self._generate_dispatcher_file() - - def _generate_common_header_file(self): - """Generate common header file with datatypes and layout.""" - - acc_type = "float" # As we are currently supporting only fp16 - - content = f""" -#pragma once + # Determine layouts based on self.layout + a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout) + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} +{pragma_line} +#include +#include +#include #include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" - -// Data types -using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]}; -using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]}; -using AccDataType = {acc_type}; -using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]}; -using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]}; -using DsDataType = ck_tile::tuple; -using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]}; - - -// Layout configurations -using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]}; -using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]}; -using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]}; -using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]}; -using DsLayout = ck_tile::tuple; -using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]}; - -// Element-wise function for D -using ElementWiseFn = ck_tile::element_wise::{self.args.function_name}; - -""" - - (self.output_dir / "gemm_multi_d_common.hpp").write_text(content) - - def _generate_all_trait_files(self): - """Generate all kernel traits into files.""" - if not self.valid_trait_names: - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - for trait in self.valid_trait_names: - self._generate_trait_file(trait) - self._generate_instantiation_source_files() - self._generate_common_instance_header_file() - - def _generate_trait_file(self, trait: str): - """Generate a trait with all tile/warp combinations.""" - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_") - filename = f"gemm_multi_d_{trait}.hpp" - - content = f""" -#pragma once - -#include "gemm_multi_d_common.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" -namespace {trait} {{ -""" - # Add template struct with configuration - content += self._generate_kernel_struct( - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k - ) +using ADataType = {get_dtype_string(self.datatype)}; +using BDataType = {get_dtype_string(self.datatype)}; +using AccDataType = {acc_type}; +using CDataType = {get_dtype_string(c_type)}; +using D0DataType = {get_dtype_string(self.datatype)}; +using D1DataType = {get_dtype_string(self.datatype)}; +using DsDataType = ck_tile::tuple; - content += f"\n}} // namespace {trait}\n" - (self.output_dir / filename).write_text(content) +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; +using D0Layout = {ds_layout[0]}; +using D1Layout = {ds_layout[1]}; +using DsLayout = ck_tile::tuple; - def _generate_kernel_struct( - self, - pipeline: str, - epilogue: str, - scheduler: str, - pad_m: str, - pad_n: str, - pad_k: str, - ) -> str: - """Generate the code block of kernel struct""" - return f""" +using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function}; -template -struct GemmKernelMultiD {{ - static constexpr bool kPadM = {pad_m}; - static constexpr bool kPadN = {pad_n}; - static constexpr bool kPadK = {pad_k}; +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; - static float launch(ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {{ - static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; - - static constexpr bool TransposeC = false; +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + // Traits + static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"}; + static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; + static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; + + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr bool TransposeC = false; - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Traits + using Traits = ck_tile::TileGemmTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>; + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using Traits = - ck_tile::TileGemmTraits; - - using GemmUniversalTraits = - ck_tile::TileGemmUniversalTraits; - - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; - - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + static float launch(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {{ + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - + float ave_time{{0}}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline)}; + + // Epilogue +""" - using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; - {EPILOGUE_MAP[epilogue]} - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + # Add epilogue configuration based on type + if epilogue == "cshuffle": + instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation>; // MemoryOperation_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; +""" + else: # default epilogue + instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; +""" - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - {{ + instance_code += f""" + + // Kernel type + using GemmKernelMultiD = ck_tile::GemmKernelMultiD; + + // Make kernel arguments + auto kargs = GemmKernelMultiD::MakeKernelArgs(args); + + if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} - - if(stream.log_level_ > 0) - {{ - std::cout << "Launching kernel with args:" - << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; + + // Get grid and block sizes + const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernelMultiD::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; }} - - ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( - Kernel{{}}, grids, blocks, 0, kargs)); - + + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); + return ave_time; - }}; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ @@ -483,279 +570,340 @@ struct GemmKernelMultiD {{ }}; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; }} - - static std::string get_name() {{ - return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + - "_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" + - std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" + - "{pad_m}" + "_" + - "{pad_n}" + "_" + - "{pad_k}" + "_" + - "{pipeline}" + "_" + - "{epilogue}" + "_" + - "{scheduler}"; - }} }}; """ + return kernel_name, instance_code - def _generate_instantiation_source_files(self): - """Generate kernel instance instantiation source files""" - tile_map = {} - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - for tile in tile_valid_params: - for ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) in tile: - key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}" - value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - if key not in tile_map: - tile_map[key] = set() - tile_map[key].add(value) + def run(self, num_workers=None): + """Run the builder to generate individual kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) - files_listed = 0 - for trait, _ in self.valid_trait_tile_combinations.items(): - for block_tile, warp_tiles in tile_map.items(): - tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map( - int, block_tile.split("x") + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + k_block_per_cu = self.config.get("k_block_per_cu") + if k_block_per_cu is None: + k_block_per_cu = 1 + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + k_block_per_cu, + self.working_path, + self.gpu_target, + self.datatype, + self.layout, + self.elementwise_function, + self.config_json, + ) ) - content = f""" -#include "gemm_multi_d_{trait}.hpp" - -""" - for warp_tile in warp_tiles: - warp_tile_m, warp_tile_n, warp_tile_k = map( - int, warp_tile.split("x") - ) - - files_listed = files_listed + 1 - content = ( - content - + f""" -template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;""" - ) - content += """ -""" - ( - self.output_dir - / f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - ).write_text(content) - print(f"Generated {files_listed} kernel instances in total.") - - def _generate_common_instance_header_file(self): - """Generate common instance header into file.""" - content = """ -#pragma once -""" - for trait in self.valid_trait_names: - content += f'#include "gemm_multi_d_{trait}.hpp"\n' - (self.output_dir / "gemm_multi_d_instances.hpp").write_text(content) - - def _generate_dispatcher_file(self): - """Generate the code block of dispatch mechanism.""" - content = """ -#pragma once - -#include -#include -#include - -#include "gemm_multi_d_common.hpp" -#include "gemm_multi_d_instances.hpp" - -/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a -/// specific kernel instance based on the provided settings. -struct KernelTraits -{ - /// @brief The name of the pipeline. - std::string pipeline; - /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). - std::string scheduler; - /// @brief The name of the epilogue (e.g., "cshuffle", "default"). - std::string epilogue; - /// @brief Indicates whether padding is applied to the M dimension. - bool pad_m; - /// @brief Indicates whether padding is applied to the N dimension. - bool pad_n; - /// @brief Indicates whether padding is applied to the K dimension. - bool pad_k; -}; - -struct GemmMultiDDispatcher { - static auto& get_kernel_map() { - // Use a static local variable - static std::unordered_map< - std::string, - std::vector(ck_tile::GemmMultiDHostArgs&, const ck_tile::stream_config&)>>> - kernel_map; - return kernel_map; - } - - static void init() { - auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; - \n""" - - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - content += f""" kernel_map["{trait}"] = {{""" - for _, tile in enumerate(tile_valid_params): - for j in range(len(tile)): - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile[j] - content += """[=](ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) { """ - - content += f""" - return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);""" - - if j == len(tile) - 1: - content += """ - } """ - else: - content += """ - }, """ - content += """ - };\n """ - - content += """ } - - template - static std::tuple run_kernel(ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) - { - std::string name = Kernel::get_name(); - float avg_time = Kernel::launch(args, stream); - - return std::make_tuple(name, avg_time); - } - - - static auto dispatch(const KernelTraits& trait) { - init(); - const std::string key = assemble_key(trait); - auto& kernel_map = get_kernel_map(); - if(auto it = kernel_map.find(key); it != kernel_map.end()) - { - return it->second; - } - throw std::runtime_error("No suitable kernel found: " + key); - } - -private: - static std::string assemble_key(const KernelTraits &trait) { - return std::string(trait.pipeline) + "_" + - trait.epilogue + "_" + - trait.scheduler + "_" + - (trait.pad_m ? "true" : "false") + "_" + - (trait.pad_n ? "true" : "false") + "_" + - (trait.pad_k ? "true" : "false"); - } -}; - -""" - (self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content) - - -def do_list_blobs( - args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None -): - generator = GemmMultiDCodeGenerator(args, user_provide_config) - generator.list_all_trait_names() - - -def do_gen_blobs( - args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None -): - generator = GemmMultiDCodeGenerator(args, user_provide_config) - generator.generate_all_instance_files() - - -def main(args): - gemm_multi_d_config = JsonConfig.from_json(args.config_json) - - if args.list_blobs: - do_list_blobs(args, gemm_multi_d_config) - elif args.gen_blobs: - do_gen_blobs(args, gemm_multi_d_config) - else: - logging.warning( - "No mode specified (use --list_blobs or --gen_blobs). Generating by default..." + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." ) - do_gen_blobs(args, gemm_multi_d_config) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM Multi D targets + # Datatype: {self.datatype}, Layout: {self.layout} + """ + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_multi_d_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open( + self.working_path / "gemm_multi_d_individual_targets.cmake", "w" + ) as f: + f.write(cmake_code) -if __name__ == "__main__": +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + ( + tile_config, + trait_combo, + k_block_per_cu, + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json, + ) = work_item + + # Create a temporary builder instance for this worker + builder = GemmMultiDKernelBuilder( + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json, + ) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo, k_block_per_cu + ) + + # Create simplified filename without the "gemm_multi_d_" prefix + # Remove "gemm_multi_d_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_multi_d_"): + simplified_name = simplified_name[13:] # Remove "gemm_multi_d_" prefix + + # Write individual header file + header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): parser = argparse.ArgumentParser( - prog="generate", - description="gen API for CK gemm multi D kernel", + description="GEMM Multi D kernel instance builder with parallel support" ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument("--gpu_target", required=True, help="GPU target architecture") parser.add_argument( - "-w", - "--working_path", - default="./", - required=False, - help="The path where all the blobs are going to be generated", - ) - parser.add_argument( - "--gpu_target", - required=True, - help="GPU target architecture", - ) - parser.add_argument( - "-j", - "--config_json", - required=False, - help="Path to the json which contains the configurations that user provide", - ) - parser.add_argument( - "-d", "--datatype", required=True, - help="Specify what datatype to use for the kernel generation, e.g. fp16", + choices=["fp16"], + help="Data type", ) parser.add_argument( - "-ly", "--layout", required=True, - help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr", + choices=["rcrr", "rrrr", "ccrr", "crrr"], + help="Matrix layout", ) parser.add_argument( - "-ef", "--elementwise_function", required=True, help="Specify what element wise function for D, e.g. mul, add, passthrough", ) + parser.add_argument("--config_json", help="Configuration JSON file") parser.add_argument( - "-l", - "--list_blobs", - action="store_true", - help="List all kernel instances to file", + "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - "-g", - "--gen_blobs", + "--gen_all_individual", action="store_true", - help="Generate all kernel instances into different files", + help="Generate individual kernel files", + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", ) args = parser.parse_args() - main(args) + assert args.datatype in ["fp16"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 4, ( + f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" + ) + assert layout_parts[2] == "r" and layout_parts[3] == "r", ( + f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} andf {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)" + ) + + # Elementwise function name validation + elementwise_function = args.elementwise_function.lower() + + valid_functions = ["mul", "add", "passthrough"] + if elementwise_function not in valid_functions: + raise ValueError( + f"Invalid elementwise function: {elementwise_function}. " + f"Valid options are: {', '.join(valid_functions)}" + ) + + # Set the function name based on the elementwise function + if elementwise_function == "mul": + function_name = "MultiDMultiply" + elif elementwise_function == "add": + function_name = "MultiDAdd" + elif elementwise_function == "passthrough": + function_name = "PassThrough" # TODO Change this + + args.elementwise_function = function_name + + # Create builder + builder = GemmMultiDKernelBuilder( + args.working_path, + args.gpu_target, + args.datatype, + args.layout, + args.elementwise_function, + args.config_json, + ) + + if args.list_kernels: + builder.write_kernel_list() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + k_block_per_cu = builder.config.get("k_block_per_cu") + if k_block_per_cu is None: + k_block_per_cu = 1 + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo, k_block_per_cu + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_multi_d_"): + simplified_name = simplified_name[13:] + + header_file = ( + builder.working_path / f"gemm_multi_d_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + elif args.gen_all_individual: + # Generate all individual kernel files + builder.run(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp index 0106d76c05..8e19c11c7d 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp @@ -9,7 +9,7 @@ #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" -#include "benchmark_gemm_multi_d.hpp" +#include "gemm_multi_d_benchmark.hpp" class GemmMultiDProfiler { @@ -20,6 +20,25 @@ class GemmMultiDProfiler return instance; } + // Overload for single kernel benchmarking + void benchmark(GemmMultiDProblem& gemm_multi_d_problem, + std::function&, + const ck_tile::stream_config&)> kernel_func) + { + // Create a vector with a single callable that returns both name and time + std::vector( + ck_tile::GemmMultiDHostArgs&, const ck_tile::stream_config&)>> + callables; + + callables.push_back([kernel_func](ck_tile::GemmMultiDHostArgs& args, + const ck_tile::stream_config& stream) { + float time = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time); + }); + + benchmark(gemm_multi_d_problem, callables); + } + void benchmark( GemmMultiDProblem& gemm_multi_d_problem, std::vector( @@ -30,7 +49,7 @@ class GemmMultiDProfiler const BLayout layout_b = BLayout{}; const D0Layout layout_d0 = D0Layout{}; const D1Layout layout_d1 = D1Layout{}; - const ELayout layout_e = ELayout{}; + const CLayout layout_c = CLayout{}; gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_, gemm_multi_d_problem.k_, @@ -50,10 +69,10 @@ class GemmMultiDProfiler gemm_multi_d_problem.n_, gemm_multi_d_problem.stride_d1_, is_row_major(layout_d1)); - gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_, + gemm_multi_d_problem.stride_c_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_, gemm_multi_d_problem.n_, - gemm_multi_d_problem.stride_e_, - is_row_major(layout_e)); + gemm_multi_d_problem.stride_c_, + is_row_major(layout_c)); ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, @@ -75,30 +94,30 @@ class GemmMultiDProfiler gemm_multi_d_problem.n_, gemm_multi_d_problem.stride_d1_, is_row_major(layout_d1))); - ck_tile::HostTensor e_m_n_device_result( + ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, gemm_multi_d_problem.n_, - gemm_multi_d_problem.stride_e_, - is_row_major(layout_e))); + gemm_multi_d_problem.stride_c_, + is_row_major(layout_c))); ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n); - ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.mData.data()); b_k_n_dev_buf.ToDevice(b_k_n.mData.data()); d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data()); d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data()); - e_m_n_dev_buf.SetZero(); - e_m_n_device_result.SetZero(); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), d1_m_n_dev_buf.GetDeviceBuffer()}; @@ -110,7 +129,7 @@ class GemmMultiDProfiler a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), ds_ptr_buf, - e_m_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), gemm_multi_d_problem.split_k_, gemm_multi_d_problem.m_, gemm_multi_d_problem.n_, @@ -118,19 +137,19 @@ class GemmMultiDProfiler gemm_multi_d_problem.stride_a_, gemm_multi_d_problem.stride_b_, stridesDs, - gemm_multi_d_problem.stride_e_, + gemm_multi_d_problem.stride_c_, }; - ck_tile::HostTensor e_m_n_host_result( + ck_tile::HostTensor c_m_n_host_result( ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_, gemm_multi_d_problem.n_, - gemm_multi_d_problem.stride_e_, - is_row_major(layout_e))); + gemm_multi_d_problem.stride_c_, + is_row_major(layout_c))); if(setting_.verify_) { gemm_multi_d_host_reference( - setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result); + setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, c_m_n_host_result); } for(auto& callable : callables) @@ -139,54 +158,58 @@ class GemmMultiDProfiler callable(gemm_multi_d_args, ck_tile::stream_config{ nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_}); - - auto [kernel_name, execution_time] = kernel_run_result; - process_result(gemm_multi_d_problem, - e_m_n_dev_buf, - e_m_n_host_result, - e_m_n_device_result, + c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, kernel_run_result); } } void process_result(const GemmMultiDProblem& gemm_multi_d_problem, - ck_tile::DeviceMem& e_m_n_dev_buf, - ck_tile::HostTensor& e_m_n_host_result, - ck_tile::HostTensor& e_m_n_dev_result, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, const std::tuple& kernel_run_result) { auto [name, avg_time] = kernel_run_result; KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}}; - static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); - std::size_t flop = 0, num_byte = 0; - flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ * - gemm_multi_d_problem.k_; - ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { + // compute performance metric + std::size_t flop = std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ * + gemm_multi_d_problem.k_; + std::size_t num_byte = + sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ + + sizeof(BDataType) * gemm_multi_d_problem.n_ * gemm_multi_d_problem.k_ + + sizeof(CDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; + + // Dth Dimension Updates + ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) { num_byte += sizeof(ck_tile::remove_cvref_t>) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; flop += sizeof(ck_tile::remove_cvref_t>) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; }); - num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ + - sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ + - sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_; + // update kernel_instance.perf_result_.latency_ = avg_time; kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; - if(setting_.log_ > 0) + if(setting_.log_ > 0 && !setting_.json_output_) { std::cout << kernel_instance << std::endl; } - e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data()); + // verify result + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool verified_correct = - !setting_.verify_ || - compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result); + !setting_.verify_ || compare(name, + gemm_multi_d_problem.k_, + 1, // Multi d currently supports only k_batch = 1 + c_m_n_dev_result, + c_m_n_host_result); if(verified_correct) { @@ -197,8 +220,9 @@ class GemmMultiDProfiler std::cout << "Verification failed, skip kernel: " << name << std::endl; } - e_m_n_dev_buf.SetZero(); - e_m_n_dev_result.SetZero(); + // clear tensor + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); } KernelInstance select_best_instance(Metric metric) @@ -213,10 +237,18 @@ class GemmMultiDProfiler b.perf_result_, a.perf_result_, metric); }); - std::cout << "**********************************" << std::endl; - std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" - << "The best kernel instance is: " << kernel_instance << std::endl; - std::cout << "**********************************" << std::endl; + if(setting_.json_output_) + { + // Output clean JSON only + std::cout << kernel_instance << std::endl; + } + else + { + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "Current kernel performance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + } if(!setting_.csv_filename_.empty()) { @@ -244,16 +276,13 @@ class GemmMultiDProfiler file << get_rocm_version() << "," << ck_tile::get_device_name() << "," << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," - << problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_ - << "," << problem.dtype_a_ << "," << problem.dtype_b_ << "," - << problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_ - << "," << problem.dtype_e_ << "," << problem.layout_a_ << "," - << problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_ - << "," << problem.layout_e_ << "," << "," << name << "," << std::fixed - << std::setprecision(4) << perf.latency_ << "," << std::fixed - << std::setprecision(4) << perf.tflops_ << "," << std::fixed - << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) - << "\n"; + << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ + << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << "," + << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_ + << "," << name << "," << std::fixed << std::setprecision(4) << perf.latency_ + << "," << std::fixed << std::setprecision(4) << perf.tflops_ << "," + << std::fixed << std::setprecision(4) << perf.bandwidth_ << "," + << get_metric_name(metric) << "\n"; if(!file) { diff --git a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt index d80d2661d1..e3bee6ff52 100644 --- a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt +++ b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt @@ -1,4 +1,4 @@ -set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") +set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF) @@ -122,15 +122,15 @@ function(build_individual_gemm_preshuffle_targets datatype layout) if(DEFINED ENV{GEMM_PRESHUFFLE_CONFIG_FILE} AND NOT "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "") set(config_filename "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}") set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") - message(STATUS " Using config from environment variable: ${config_filename}") + message(VERBOSE " Using config from environment variable: ${config_filename}") elseif(NOT "${GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "") # Use CMake variable if set set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_PRESHUFFLE_CONFIG_FILE}") - message(STATUS " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}") + message(VERBOSE " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}") else() # Use default config for all layouts set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - message(STATUS " Using default config for layout ${layout}") + message(VERBOSE " Using default config for layout ${layout}") endif() # Check if config file exists @@ -151,18 +151,18 @@ function(build_individual_gemm_preshuffle_targets datatype layout) endif() # Generate individual kernel files using parallel version - message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") - message(STATUS " Working path: ${working_path}") - message(STATUS " Config file: ${json_blob}") - message(STATUS " Python executable: ${Python3_EXECUTABLE}") - message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py") + message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(VERBOSE " Working path: ${working_path}") + message(VERBOSE " Config file: ${json_blob}") + message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") + message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py") # Create working directory first file(MAKE_DIRECTORY ${working_path}) # First, just list the kernels (fast operation) - message(STATUS " Listing kernel configurations...") - message(STATUS " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}") + message(VERBOSE " Listing kernel configurations...") + message(VERBOSE " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}") execute_process( COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py --working_path ${working_path} @@ -185,7 +185,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout) if(EXISTS ${working_path}/gemm_preshuffle_kernel_count.txt) file(READ ${working_path}/gemm_preshuffle_kernel_count.txt kernel_count) string(STRIP "${kernel_count}" kernel_count) - message(STATUS " Found ${kernel_count} kernel configurations") + message(VERBOSE " Found ${kernel_count} kernel configurations") else() message(FATAL_ERROR "Kernel count file not found") endif() @@ -209,10 +209,10 @@ function(build_individual_gemm_preshuffle_targets datatype layout) endfunction() # Main build logic - Only individual builds supported -message(STATUS "=== Starting Tile Engine GEMM Preshuffle Configuration ===") -message(STATUS "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}") -message(STATUS "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}") -message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +message(VERBOSE "=== Starting Tile Engine GEMM Preshuffle Configuration ===") +message(VERBOSE "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}") +message(VERBOSE "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}") +message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") # Filter GPU targets to only gfx90a, gfx942, and gfx950 set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "") @@ -221,7 +221,7 @@ set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) list(APPEND GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL ${target}) - message(STATUS " Adding GPU target: ${target}") + message(VERBOSE " Adding GPU target: ${target}") endif() endforeach() @@ -229,7 +229,7 @@ endforeach() if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL) message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") else() - message(STATUS "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}") + message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}") # Enable parallel compilation optimizations # Set up job pools for better parallel compilation control @@ -244,12 +244,12 @@ else() find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) - message(STATUS "Using ccache for faster compilation") + message(VERBOSE "Using ccache for faster compilation") else() message(WARNING "ccache requested but not found") endif() else() - message(STATUS "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)") + message(VERBOSE "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)") endif() # Create master collection targets