diff --git a/Jenkinsfile b/Jenkinsfile index 1f3ca6d78b..cd5f8b60c2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1637,18 +1637,18 @@ pipeline { -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx90a" \ - -D GEMM_DATATYPE="fp8;fp16" \ - -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \ + -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ - python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ + ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ + python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) @@ -1668,18 +1668,18 @@ pipeline { -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx942" \ - -D GEMM_DATATYPE="fp8;fp16" \ - -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \ + -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ - python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ + ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ + python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) @@ -1699,10 +1699,10 @@ pipeline { -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx1201" \ - -D GEMM_DATATYPE="fp16" \ - -D GEMM_LAYOUT="rcr;rrr;crr;ccr" .. && \ - ninja -j64 benchmark_gemm_all && \ - python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ + -D GEMM_UNIVERSAL_DATATYPE="fp16" \ + -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" .. && \ + ninja -j${nthreads()} benchmark_gemm_universal_all && \ + python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt index 7f5e2fa298..f63453e21b 100644 --- a/tile_engine/CMakeLists.txt +++ b/tile_engine/CMakeLists.txt @@ -5,4 +5,6 @@ include_directories(BEFORE ${CMAKE_CURRENT_LIST_DIR}/include ) -add_subdirectory(ops) +add_subdirectory(ops/gemm) +add_subdirectory(ops/gemm_streamk) + diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt deleted file mode 100644 index 6f82e1b07a..0000000000 --- a/tile_engine/ops/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-# SPDX-License-Identifier: MIT - -add_subdirectory(gemm) -add_subdirectory(gemm_multi_d) -add_subdirectory(gemm_preshuffle) -add_subdirectory(gemm_streamk) diff --git a/tile_engine/ops/commons/test_benchmark.sh b/tile_engine/ops/commons/test_benchmark.sh deleted file mode 100755 index e2e0324da8..0000000000 --- a/tile_engine/ops/commons/test_benchmark.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/bash -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - - -# Test script for tile engine GEMM benchmarks -# This script demonstrates how to run the new individual benchmark executables - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -NC='\033[0m' # No Color - -# Find the build directory -if [ -z "$1" ]; then - # Try to find build directory automatically - BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1) - if [ -z "$BUILD_DIR" ]; then - echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}" - echo "Usage: $0 " - exit 1 - fi -else - BUILD_DIR="$1" -fi - -echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}" - -# Check if bin directory exists -if [ ! 
-d "$BUILD_DIR/bin" ]; then - echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}" - exit 1 -fi - -# Find all benchmark executables -echo -e "${YELLOW}Finding benchmark executables...${NC}" -BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null) - -if [ -z "$BENCHMARKS" ]; then - echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}" - echo "Please build some benchmarks first with:" - echo " cd $BUILD_DIR" - echo " make benchmark_gemm_" - exit 1 -fi - -# Count benchmarks -NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l) -echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}" - -# Test sizes -SIZES=(512 1024 2048) - -# Results file -RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv" - -echo -e "${YELLOW}Running benchmarks...${NC}" -echo "Results will be saved to: $RESULTS_FILE" - -# Run each benchmark -COUNTER=0 -for BENCH in $BENCHMARKS; do - COUNTER=$((COUNTER + 1)) - BENCH_NAME=$(basename "$BENCH") - echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}" - - for SIZE in "${SIZES[@]}"; do - echo -e " Testing size: ${SIZE}x${SIZE}x${SIZE}" - - # Run with verification - "$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \ - -csv_filename="$RESULTS_FILE" -csv_format=simple \ - 2>&1 | grep -E "(Time:|Performance:|Verification:|Error)" - - if [ ${PIPESTATUS[0]} -ne 0 ]; then - echo -e " ${RED}Benchmark failed!${NC}" - fi - done -done - -echo -e "\n${GREEN}Benchmark testing complete!${NC}" -echo "Results saved to: $RESULTS_FILE" - -# Show summary if CSV file exists -if [ -f "$RESULTS_FILE" ]; then - echo -e "\n${YELLOW}Summary of results:${NC}" - echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)" - echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")" - echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")" -fi - -# Example of running a specific benchmark with different options -echo -e "\n${YELLOW}Example commands for manual testing:${NC}" -echo 
"# Basic run:" -echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024" -echo "" -echo "# With CPU verification:" -echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1" -echo "" -echo "# JSON output for parsing:" -echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true" -echo "" -echo "# Performance testing with TFLOPS metric:" -echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1" diff --git a/tile_engine/ops/commons/test_validation.py b/tile_engine/ops/commons/test_validation.py deleted file mode 100644 index 46fb008c27..0000000000 --- a/tile_engine/ops/commons/test_validation.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Test script to verify that the validation logic is working correctly. 
-""" - -from validation_utils import ( - is_tile_config_valid, - is_trait_combination_valid, - validate_warp_tile_combination, -) - - -def test_warp_tile_validation(): - """Test warp tile combination validation""" - print("Testing warp tile combination validation...") - - # Get GPU name - gpu_name = "gfx90a" - - # Test cases for fp16 - test_cases = [ - # (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid) - ([4, 64, 8], False), # Invalid - not in supported list - ([4, 64, 16], True), # Valid - ([32, 32, 8], True), # Valid - ([16, 16, 16], True), # Valid - ([32, 32, 16], True), # Valid - ([16, 16, 32], True), # Valid - ([64, 4, 16], True), # Valid - ([128, 128, 128], False), # Invalid - too large - ] - - print("\nTesting fp16 warp tile combinations:") - for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases: - valid, msg = validate_warp_tile_combination( - warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name - ) - status = "PASS" if valid == expected else "FAIL" - print(f" [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}") - if not valid and msg: - print(f" Reason: {msg}") - - -def test_trait_combinations(): - """Test trait combination validation""" - print("\n\nTesting trait combination validation...") - - test_cases = [ - # (pipeline, epilogue, scheduler, expected_valid) - ("mem", "default", "intrawave", True), - ("mem", "cshuffle", "intrawave", True), - ("compv3", "default", "interwave", False), # Invalid combination - ("compv3", "cshuffle", "interwave", False), # Invalid combination - ("compv4", "default", "interwave", False), # Invalid combination - ("compv4", "cshuffle", "interwave", False), # Invalid combination - ("compv3", "default", "intrawave", True), - ("compv4", "cshuffle", "intrawave", True), - ] - - print("\nTesting trait combinations:") - for pipeline, epilogue, scheduler, expected in test_cases: - valid = is_trait_combination_valid(pipeline, epilogue, scheduler) - status = "PASS" if valid == 
expected else "FAIL" - print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}") - - -def test_full_tile_config_validation(): - """Test full tile configuration validation""" - print("\n\nTesting full tile configuration validation...") - - # Test case that was failing in the build - tile_m, tile_n, tile_k = 256, 256, 32 - warp_m, warp_n, warp_k = 1, 4, 1 - warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8 - - print("\nTesting problematic configuration:") - print(f" Tile: {tile_m}x{tile_n}x{tile_k}") - print(f" Warp: {warp_m}x{warp_n}x{warp_k}") - print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") - - valid = is_tile_config_valid( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - "fp16", - "fp16", - "fp16", - "mem", - ) - - print(f" Valid: {valid}") - print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)") - - # Test a valid configuration - warp_tile_k = 16 # Change to valid value - print("\nTesting corrected configuration:") - print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") - - valid = is_tile_config_valid( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - "fp16", - "fp16", - "fp16", - "mem", - ) - - print(f" Valid: {valid}") - print(" Expected: True") - - -def main(): - """Run all tests""" - print("=" * 60) - print("GEMM Validation Test Suite") - print("=" * 60) - - test_warp_tile_validation() - test_trait_combinations() - test_full_tile_config_validation() - - print("\n" + "=" * 60) - print("Test suite completed") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index ff18291c00..94bb928f79 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,310 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") -set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)") -set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") -option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) - -# Store the directory path for use in functions -set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) - -# Function to create individual GEMM targets -function(create_individual_gemm_target datatype layout trait tile_config config_json) - # Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable - if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") - return() - endif() - - # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k - # First split by underscore to get three groups - string(REPLACE "_" ";" config_groups ${tile_config}) - list(GET config_groups 0 tile_dims) # e.g., 256x256x32 - list(GET config_groups 1 warp_dims) # e.g., 4x1x1 - list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 - - # Parse tile dimensions - string(REPLACE "x" ";" tile_parts ${tile_dims}) - list(GET tile_parts 0 tile_m) - list(GET tile_parts 1 tile_n) - list(GET tile_parts 2 tile_k) - - # Parse warp dimensions - string(REPLACE "x" ";" warp_parts ${warp_dims}) - list(GET warp_parts 0 warp_m) - list(GET warp_parts 1 warp_n) - list(GET warp_parts 2 warp_k) - - # Parse warp tile dimensions - string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) - list(GET warp_tile_parts 0 warp_tile_m) - list(GET warp_tile_parts 1 warp_tile_n) - list(GET warp_tile_parts 2 warp_tile_k) - - set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}") - set(working_path 
"${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Generate the single instance header for this kernel - set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") - - # Add custom command to generate the header file at build time - add_custom_command( - OUTPUT ${instance_header} - COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${config_json} - --gen_single - --kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}" - --tile_config "${tile_config}" - --trait_combo "${trait}" - --gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}" - DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json} - COMMENT "Generating ${instance_header}" - ) - - # Create the executable - add_executable(${target_name} - # to save build time, exclude the target from "all" target of "gemm" directory and its ancestors - EXCLUDE_FROM_ALL - ${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp - ${instance_header} - ) - - # Set GPU architectures - set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) - - # Set compile definitions - target_compile_definitions(${target_name} PRIVATE - GEMM_SINGLE_INSTANCE_HPP="${instance_header}" - ) - - # Include directories - target_include_directories(${target_name} PRIVATE - ${GEMM_SOURCE_DIR} - ${working_path} - ) - - # Compile options - target_compile_options(${target_name} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress - -include ${instance_header} - ) - - # Add to collection targets - add_dependencies(benchmark_gemm_all ${target_name}) - add_dependencies(benchmark_gemm_${datatype} ${target_name}) - add_dependencies(benchmark_gemm_${layout} ${target_name}) - add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name}) - - # Add to trait-specific targets - string(REPLACE "_" ";" trait_parts ${trait}) - 
list(GET trait_parts 0 pipeline) - list(GET trait_parts 1 epilogue) - list(GET trait_parts 2 scheduler) - - add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name}) - add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name}) - add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name}) -endfunction() - -# Function to build individual GEMM targets -function(build_individual_gemm_targets datatype layout) - set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Choose config file - # Priority order: - # 1. Environment variable GEMM_CONFIG_FILE - # 2. CMake variable GEMM_CONFIG_FILE - # 3. Default based on layout - - # Check environment variable first - if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") - set(config_filename "$ENV{GEMM_CONFIG_FILE}") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") - message(VERBOSE " Using config from environment variable: ${config_filename}") - elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") - # Use CMake variable if set - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") - message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}") - else() - # Use default config for all layouts - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - message(VERBOSE " Using default config for layout ${layout}") - endif() - - # Check if config file exists - if(NOT EXISTS ${json_blob}) - message(FATAL_ERROR "Config file not found: ${json_blob}") - endif() - - # Determine number of workers for parallel generation - if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) - set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) - else() - # Use processor count but limit to avoid memory issues - cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) - math(EXPR num_workers "${num_cores}") - if(num_workers GREATER 8) - set(num_workers 8) - endif() - endif() - - # Generate individual kernel files using 
parallel version - message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") - message(VERBOSE " Working path: ${working_path}") - message(VERBOSE " Config file: ${json_blob}") - message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") - message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") - - # Create working directory first - file(MAKE_DIRECTORY ${working_path}) - - message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL} - --list_kernels ") - - # First, just list the kernels (fast operation) - message(VERBOSE " Listing kernel configurations...") - execute_process( - COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path ${working_path} - --datatype ${datatype} - --layout ${layout} - --config_json ${json_blob} - --gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL} - --list_kernels - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} - RESULT_VARIABLE ret - OUTPUT_VARIABLE list_output - ERROR_VARIABLE list_error - ) - - if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") - endif() - - # Read kernel count - if(EXISTS ${working_path}/gemm_kernel_count.txt) - file(READ ${working_path}/gemm_kernel_count.txt kernel_count) - string(STRIP "${kernel_count}" kernel_count) - message(VERBOSE " Found ${kernel_count} kernel configurations") - else() - message(FATAL_ERROR "Kernel count file not found") - endif() - - # Read kernel list and create targets - if(EXISTS ${working_path}/gemm_kernel_list.txt) - file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) - foreach(line IN LISTS kernel_lines) - # Parse line: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(GET parts 0 
kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) - - # Create individual target - create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") - endforeach() - else() - message(FATAL_ERROR "Kernel list file not found") - endif() -endfunction() - -# Main build logic - Only individual builds supported -message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===") -message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}") -message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}") -message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") - -# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 -set(GEMM_GPU_TARGETS_INDIVIDUAL "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") - -foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target}) - message(VERBOSE " Adding GPU target: ${target}") - endif() -endforeach() - -# Skip build if no matching targets found -if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") -else() - message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") - - # Enable parallel compilation optimizations - # Set up job pools for better parallel compilation control - set_property(GLOBAL PROPERTY JOB_POOLS - compile_heavy=4 # Limit heavy compilations to prevent OOM - compile_normal=16 # Allow more parallel normal compilations - ) - - # Enable compiler cache if available and explicitly requested - # Disabled by default due to permission issues in CI environments - if(ENABLE_CCACHE_GEMM) - find_program(CCACHE_PROGRAM ccache) - if(CCACHE_PROGRAM) - set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) - message(VERBOSE "Using ccache for faster compilation") - else() - message(WARNING "ccache requested but not 
found") - endif() - else() - message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") - endif() - - # Create master collection targets - add_custom_target(benchmark_gemm_all) - - # Create datatype collection targets - foreach(dt IN LISTS GEMM_DATATYPE) - add_custom_target(benchmark_gemm_${dt}) - endforeach() - - # Create layout collection targets - foreach(l IN LISTS GEMM_LAYOUT) - add_custom_target(benchmark_gemm_${l}) - endforeach() - - # Create combined collection targets - foreach(dt IN LISTS GEMM_DATATYPE) - foreach(l IN LISTS GEMM_LAYOUT) - add_custom_target(benchmark_gemm_${dt}_${l}) - endforeach() - endforeach() - - # Create trait-based collection targets - # These are common trait components used across all GEMM kernels - set(GEMM_PIPELINES "mem;compv3;compv4") - set(GEMM_EPILOGUES "default;cshuffle") - set(GEMM_SCHEDULERS "intrawave;interwave") - - foreach(pipeline IN LISTS GEMM_PIPELINES) - add_custom_target(benchmark_gemm_${pipeline}_pipeline) - endforeach() - - foreach(epilogue IN LISTS GEMM_EPILOGUES) - add_custom_target(benchmark_gemm_${epilogue}_epilogue) - endforeach() - - foreach(scheduler IN LISTS GEMM_SCHEDULERS) - add_custom_target(benchmark_gemm_${scheduler}_scheduler) - endforeach() - - # Build individual targets for each datatype/layout combination - foreach(dt IN LISTS GEMM_DATATYPE) - foreach(l IN LISTS GEMM_LAYOUT) - build_individual_gemm_targets(${dt} ${l}) - endforeach() - endforeach() -endif() +add_subdirectory(gemm_universal) +add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_preshuffle) \ No newline at end of file diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md deleted file mode 100644 index ce62f8dca5..0000000000 --- a/tile_engine/ops/gemm/README.md +++ /dev/null @@ -1,442 +0,0 @@ -# CK Tile Engine GEMM Operations - -## Overview - -The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix 
Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities. - -## Table of Contents - -1. [Build System Architecture](#build-system-architecture) -2. [Build Instructions](#build-instructions) -3. [Running Benchmarks](#running-benchmarks) -4. [Configuration System](#configuration-system) -5. [Scripts and Tools](#scripts-and-tools) -6. [Command Line Options](#command-line-options) -7. [Understanding Kernel Names](#understanding-kernel-names) -8. [Troubleshooting](#troubleshooting) -9. [Performance Tips](#performance-tips) - -## Build System Architecture - -### Individual Kernel Compilation (New Approach) - -The new tile engine benchmark system compiles each kernel configuration into a separate executable. This provides: -- Better build parallelism -- Faster incremental builds -- More targeted testing -- Easier debugging of specific configurations - -Each benchmark executable follows the naming pattern: -``` -benchmark_gemm____ -``` - -### Monolithic Build (Legacy Approach) - -The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments. - -## Build Instructions - -### Prerequisites -- ROCm installation -- CMake 3.16 or higher -- C++17 compatible compiler - -### Basic Build - -```bash -# In the root of composable kernel, create build directory -mkdir build && cd build - -# Configure with specific datatypes and layouts -# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942) -# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64) -# Replace [Layout1;Layout2;...] 
with layouts (rcr, rrr, crr, ccr) -../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" - -# Build specific benchmarks -make benchmark_gemm_[Datatype1]_[Layout1] -j -``` - -### Configuration Options - -The build system supports several configuration options: - -#### Using Custom Config Files -```bash -# Method 1: CMake variable (config file must be in configs/ directory) -cmake -DGEMM_CONFIG_FILE=my_custom_config.json ... - -# Method 2: Environment variable (takes precedence over CMake variable) -export GEMM_CONFIG_FILE=my_custom_config.json -cmake ... -``` - -#### Config File Priority Order -1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority) -2. **CMake variable** `GEMM_CONFIG_FILE` -3. **Default config** (default_config.json for all layouts) - -**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory. - -### Example Build Commands - -```bash -# Build for gfx942 with fp8 and fp16 datatypes, rcr layout -mkdir build && cd build -../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr" -make benchmark_gemm_fp8_rcr -j -make benchmark_gemm_fp16_rcr -j -``` - -### Building Individual Kernels - -```bash -# Build a specific kernel configuration -make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32 - -# Build all fp16 benchmarks in parallel -make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}') -``` - -### Rebuilding After Configuration Changes - -If you modify the configuration file, you must rebuild: -```bash -rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j -``` - -## Running Benchmarks - -### Individual Kernel Execution - -```bash -cd /path/to/build/directory -./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \ - -m=512 -n=512 -k=512 -verify=1 -``` - -### Monolithic Executable 
(Legacy) - -```bash -# Run specific pipeline/scheduler/epilogue combination -./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default -``` - -### Automated Testing - -Use the provided test script to run multiple benchmarks: -```bash -cd /path/to/composable_kernel/tile_engine/ops/gemm -./test_benchmark.sh [build_directory] -``` - -## Configuration System - -### Configuration Files - -The system uses JSON configuration files to specify kernel parameters: - -- `configs/default_config.json` - Default configurations for various datatypes -- `configs/user_provided_config.json` - User-customizable configurations - -### Configuration Structure - -```json -{ - "tile_config": { - "tile_m": {"values": [256, 128]}, - "tile_n": {"values": [256, 128]}, - "tile_k": {"values": [64, 32]}, - "warp_m": {"values": [2, 4]}, - "warp_n": {"values": [2, 1]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [32, 16]}, - "warp_tile_n": {"values": [32, 16]}, - "warp_tile_k": {"values": [16, 32]} - }, - "trait_config": { - "pipeline": {"values": ["compv3", "compv4", "mem"]}, - "scheduler": {"values": ["intrawave", "interwave"]}, - "epilogue": {"values": ["default", "cshuffle"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false]} - } -} -``` - -## Scripts and Tools - -### Python Scripts - -#### gemm_instance_builder.py -**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files. 
- -**Key Features**: -- Generates individual kernel header files for separate compilation -- Supports multiple data types (fp16, fp8, bf16, fp32, fp64) -- Validates tile configurations for correctness -- Creates CMake integration files - -**Usage**: -```bash -python gemm_instance_builder.py \ - --working_path ./generated \ - --datatype fp16 \ - --layout rcr \ - --config_json configs/user_provided_config.json \ - --gen_all_individual -``` - -#### gemm_instance_builder_parallel.py -**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations. - -**Features**: -- Multi-threaded kernel generation -- Improved performance for large configuration spaces - -#### validation_utils.py -**Purpose**: Provides comprehensive validation functions for kernel configurations. - -**Key Functions**: -- `is_tile_config_valid()` - Validates tile dimensions and alignments -- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported -- `validate_warp_tile_combination()` - GPU-specific warp tile validation -- `validate_lds_capacity()` - Ensures configurations fit in LDS memory - -**Validation Checks**: -- Dimension alignment (tile dimensions must be divisible by warp dimensions) -- LDS capacity constraints -- GPU-specific warp tile support -- Unsupported trait combinations - -#### test_validation.py -**Purpose**: Test suite for the validation logic to ensure correctness. - -**Usage**: -```bash -python test_validation.py -``` - -**Tests**: -- Warp tile combination validation -- Trait combination validation -- Full tile configuration validation - -#### gemm_benchmark.py -**Purpose**: Python script for running and analyzing GEMM benchmarks. - -**Features**: -- Automated benchmark execution -- Performance data collection -- Result analysis and reporting - -#### json_config.py -**Purpose**: Configuration file parsing and management. 
- -**Features**: -- JSON configuration loading -- Default configuration handling -- Configuration validation - -#### codegen_utils.py -**Purpose**: Utility functions for code generation. - -**Features**: -- Template processing -- Code formatting utilities -- File generation helpers - -### Shell Scripts - -#### test_benchmark.sh -**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables. - -**Features**: -- Automatic build directory detection -- Batch execution of multiple benchmarks -- CSV result collection -- Colored output for easy reading -- Example command generation - -**Usage**: -```bash -# Auto-detect build directory -./test_benchmark.sh - -# Specify build directory -./test_benchmark.sh /path/to/build/directory -``` - -**What it does**: -1. Finds all benchmark executables in the build directory -2. Runs each with multiple problem sizes (512, 1024, 2048) -3. Performs GPU verification -4. Saves results to timestamped CSV file -5. Provides summary statistics - -## Command Line Options - -All benchmark executables support the following options: - -### Matrix Dimensions -- `-m=` - M dimension (default: 3840) -- `-n=` - N dimension (default: 4096) -- `-k=` - K dimension (default: 2048) - -### Strides -- `-stride_a=` - Stride for matrix A (default: 0, auto-calculated) -- `-stride_b=` - Stride for matrix B (default: 0, auto-calculated) -- `-stride_c=` - Stride for matrix C (default: 0, auto-calculated) - -### Verification -- `-verify=<0|1|2>` - Verification mode - - 0: No verification (default) - - 1: CPU verification - - 2: GPU verification - -### Performance Testing -- `-warmup=` - Warmup iterations (default: 50) -- `-repeat=` - Benchmark iterations (default: 100) -- `-timer=` - Use GPU timer (default: true) -- `-flush_cache=` - Flush cache between runs (default: true) -- `-rotating_count=` - Cache rotation count (default: 1000) - -### Initialization -- `-init=<0|1|2>` - Tensor initialization method - - 0: Random values 
[-1, 1] (default) - - 1: Linear sequence (i % 17) - - 2: Constant value (1.0) - -### Output Options -- `-log=` - Enable verbose logging (default: false) -- `-metric=<0|1|2>` - Performance metric - - 0: Latency in ms (default) - - 1: TFLOPS - - 2: Bandwidth in GB/s -- `-json_output=` - JSON format output (default: false) -- `-csv_filename=` - Save results to CSV -- `-csv_format=` - CSV format (default: comprehensive) - -### Advanced Options -- `-split_k=` - Split-K factor (default: 1) -- `-structured_sparsity=` - Enable structured sparsity (default: false) -- `-pipeline=` - Pipeline type (default: compv3) -- `-scheduler=` - Scheduler type (default: intrawave) -- `-epilogue=` - Epilogue type (default: cshuffle) -- `-pad_m=` - Pad M dimension (default: false) -- `-pad_n=` - Pad N dimension (default: false) -- `-pad_k=` - Pad K dimension (default: false) -- `-persistent=` - Use persistent kernel (default: false) - -## Understanding Kernel Names - -The kernel naming convention encodes the configuration: - -``` -benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 - ^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^ ^^^^^^^ ^^^^^^^^^ - | | | | | | | | | - | | | | | Padding & flags | | Warp tile - | | | | Scheduler | Thread tile - | | | Epilogue Block tile - | | Pipeline - | Layout (Row-Column-Row) - Data type -``` - -### Components: -- **Data type**: fp16, fp32, bf16, fp8, bf8, int8 -- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr -- **Pipeline**: mem, compv3, compv4 -- **Epilogue**: default, cshuffle -- **Scheduler**: intrawave, interwave -- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags) -- **Tile sizes**: BlockTile x ThreadTile x WarpTile - -## Troubleshooting - -### Common Issues - -1. **Kernel not found** - - Ensure the specific benchmark executable is built - - Check the build directory bin/ folder - -2. 
**Verification failures** - - Try GPU verification (-verify=2) which may be more accurate - - Check data type compatibility - - Verify stride calculations - -3. **Build failures** - - Check GPU architecture compatibility - - Ensure ROCm is properly installed - - Verify configuration file syntax - -4. **Performance variations** - - Increase warmup iterations - - Disable CPU frequency scaling - - Use GPU timer for accurate measurements - -### Debug Options - -Enable verbose logging: -```bash -./bin/benchmark_gemm_... -log=true -verify=1 -``` - -Test validation logic: -```bash -python test_validation.py -``` - -## Performance Tips - -1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions -2. **Warmup**: Use at least 50-100 warmup iterations -3. **GPU Timer**: Always use `-timer=true` for accurate measurements -4. **Cache Management**: Enable cache flushing for consistent results -5. **Thread Affinity**: Set CPU affinity to reduce variation - -## Integration Examples - -### Python Integration - -```python -import subprocess -import json - -# Run benchmark with JSON output -result = subprocess.run([ - './bin/benchmark_gemm_fp16_rcr_...', - '-m=1024', '-n=1024', '-k=1024', - '-json_output=true' -], capture_output=True, text=True) - -# Parse results -data = json.loads(result.stdout) -print(f"Performance: {data['tflops']} TFLOPS") -``` - -### Batch Testing Script - -```bash -#!/bin/bash -SIZES="512 1024 2048 4096" -for size in $SIZES; do - echo "Testing ${size}x${size}x${size}" - ./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \ - -verify=2 -csv_filename=results.csv -done -``` - -## Contributing - -When adding new features or configurations: -1. Update validation logic in `validation_utils.py` -2. Add tests to `test_validation.py` -3. Update configuration examples -4. Document new command-line options - -For more information about the Composable Kernel project, visit the main repository documentation. 
diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp deleted file mode 100644 index 1fdc63b33b..0000000000 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" - -// Helper function to determine if a layout is row-major -template -constexpr auto is_row_major(Layout) -{ - return ck_tile::bool_constant>{}; -} - -// Structure to hold kernel traits for dispatcher -struct KernelTraits -{ - std::string pipeline; // compv3, compv4, mem - std::string scheduler; // intrawave, interwave - std::string epilogue; // cshuffle, default - bool pad_m; - bool pad_n; - bool pad_k; - bool persistent; - - // Constructor with defaults - KernelTraits() - : pipeline("compv3"), - scheduler("intrawave"), - epilogue("cshuffle"), - pad_m(false), - pad_n(false), - pad_k(false), - persistent(false) - { - } -}; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 65fede6a5f..27ca805c2e 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,17 +1,12 @@ -#!/usr/bin/env python # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT - import os import json -import argparse -import itertools -import multiprocessing -import concurrent.futures from pathlib import Path -import logging import importlib.util +import itertools +import logging def _import_validation_utils(): @@ -22,7 +17,7 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( "validation_utils", - os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), + os.path.join(parent_dir, "gemm", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) @@ -34,14 +29,22 @@ def _import_validation_utils(): _validation_utils = _import_validation_utils() is_tile_config_valid = _validation_utils.is_tile_config_valid is_trait_combination_valid = _validation_utils.is_trait_combination_valid -get_dtype_string = _validation_utils.get_dtype_string get_abc_layouts = _validation_utils.get_abc_layouts - -logging.basicConfig(level=logging.INFO) +get_abcd_layouts = _validation_utils.get_abcd_layouts +get_dtype_string = _validation_utils.get_dtype_string class GemmKernelBuilder: - def __init__(self, working_path, gpu_target, datatype, layout, config_json=None): + def __init__( + self, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json=None, + ): + self.kernel_name_prefix = kernel_name_prefix self.working_path = Path(working_path) self.gpu_target = gpu_target self.datatype = datatype @@ -56,10 +59,10 @@ class GemmKernelBuilder: with open(config_json, "r") as f: self.config = json.load(f) - def write_kernel_list(self): + def _list_kernels(self): """Write kernel list to file for CMake to read (with comprehensive validation)""" # Get configurations using comprehensive validation - tile_configs = self._get_tile_configs(fast_mode=False) + tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() kernel_list = [] @@ -76,7 +79,7 @@ 
class GemmKernelBuilder: ) = trait_combo # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" # Create tile configuration string tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" @@ -94,11 +97,15 @@ class GemmKernelBuilder: ) # Write kernel count - with open(self.working_path / "gemm_kernel_count.txt", "w") as f: + with open( + self.working_path / f"{self.kernel_name_prefix}_kernel_count.txt", "w" + ) as f: f.write(str(len(kernel_list))) # Write kernel list - with open(self.working_path / "gemm_kernel_list.txt", "w") as f: + with open( + self.working_path / f"{self.kernel_name_prefix}_kernel_list.txt", "w" + ) as f: for kernel in kernel_list: # Format: kernel_name|tile_config|trait_combo tile_config = kernel["tile_config"] @@ -117,8 +124,9 @@ class GemmKernelBuilder: print(f"Listed {len(kernel_list)} kernel configurations") - def _get_tile_configs(self, fast_mode=False): + def _get_tile_configs(self): """Get tile configurations for the current datatype and layout""" + tile_config = self.config["tile_config"] # Generate values in the config if default range is given @@ -153,6 +161,14 @@ class GemmKernelBuilder: warp_tile_k_values = tile_config.get("warp_tile_k").get("values") # Generate all combinations + default_pipeline = "" + if self.kernel_name_prefix == "gemm_universal": + default_pipeline = "compv4" + elif self.kernel_name_prefix == "gemm_multi_d": + default_pipeline = "compv4" + elif self.kernel_name_prefix == "gemm_preshuffle": + default_pipeline = "preshufflev2" + configs = [] for tile_m in tile_m_values: for 
tile_n in tile_n_values: @@ -174,7 +190,7 @@ class GemmKernelBuilder: warp_tile_m, warp_tile_n, warp_tile_k, - fast_mode=fast_mode, + default_pipeline, ): configs.append( { @@ -211,59 +227,47 @@ class GemmKernelBuilder: warp_tile_m, warp_tile_n, warp_tile_k, - pipeline="compv4", # Default pipeline for validation - fast_mode=False, # Add fast mode option + pipeline, ): """Validate that tile configuration is reasonable""" - if fast_mode: - # Fast validation for listing - only basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + # Validate preshuffle specific constraints + if ( + self.config.get("permute_n") is not None + and self.config.get("permute_n") is True + ): + valid = (tile_n / warp_tile_n / warp_n) % 2 == 0 + if not valid: return False - # Basic divisibility check - if tile_m % (warp_m * warp_tile_m) != 0: - return False - if tile_n % (warp_n * warp_tile_n) != 0: - return False - if tile_k % (warp_k * warp_tile_k) != 0: - return False + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype - return True - else: - # Full validation for generation - # Determine data types for validation - a_datatype = self.datatype - b_datatype = self.datatype - c_datatype = self.datatype + layout = self.layout - layout = self.layout + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" - # Special handling for certain data types - if self.datatype in ["fp8", "bf8"]: - c_datatype = "fp16" - - # Use the comprehensive validation function - return is_tile_config_valid( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - pipeline, - layout, - self.gpu_target, - ) + # Use the comprehensive validation function + 
return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + self.gpu_target, + ) def _generate_trait_combinations(self): """Generate all combinations of traits""" @@ -302,10 +306,11 @@ class GemmKernelBuilder: ) return combinations - def _generate_kernel_instance( - self, tile_config, trait_combo, k_block_per_cu, is_header=True - ): + def _generate_kernel_instance(self, tile_config, trait_combo): """Generate a single kernel instance""" + + k_block_per_cu = self.config.get("k_block_per_cu", 1) + ( pipeline, epilogue, @@ -317,7 +322,7 @@ class GemmKernelBuilder: ) = trait_combo # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" # Create tile configuration string tile_str = ( @@ -330,35 +335,71 @@ class GemmKernelBuilder: kernel_name += f"_{tile_str}" - # Map pipeline names to the correct pipeline implementation - pipeline_impl_map = { - "mem": "ck_tile::GemmPipelineAgBgCrMem", - "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", - "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", - } + if self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d"]: + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": 
"ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } + elif self.kernel_name_prefix == "gemm_preshuffle": + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2", + } + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } - # Map scheduler names to the correct enum values scheduler_type_map = { "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", "interwave": "ck_tile::GemmPipelineScheduler::Interwave", "default": "ck_tile::GemmPipelineScheduler::Default", } - # Determine accumulator type based on datatype - acc_type = "float" + instance_code = self.populate_kernel_header(kernel_name) + instance_code += self.populate_kernel_dtype_layout() + instance_code += self.populate_strut_begin(kernel_name) + instance_code += self.populate_tile_config(tile_config) + instance_code += self.populate_trait_config(trait_combo) + instance_code += self.populate_initialization(base_pipeline_map, pipeline) + instance_code += self.populate_launch( + scheduler_type_map, + scheduler, + pipeline_impl_map, + pipeline, + epilogue, + k_block_per_cu, + persistent, + ) - # Determine output type - c_type = self.datatype - if self.datatype in ["fp8", "bf8"]: - c_type = "fp16" + # Write into a file + simplified_name = kernel_name + if simplified_name.startswith(f"{self.kernel_name_prefix}_"): + simplified_name = simplified_name[len(self.kernel_name_prefix) + 1 :] - # Determine layouts based on self.layout - a_layout, b_layout, c_layout = get_abc_layouts(self.layout) + header_file = ( + self.working_path + / f"{self.kernel_name_prefix}_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + f.write(instance_code) - # Generate kernel instance code using the correct API - pragma_line = "#pragma once\n" if is_header else "" + 
print(f"Generated {header_file}") + + return kernel_name, instance_code + + def populate_kernel_header(self, kernel_name): instance_code = f"""// Generated kernel instance for {kernel_name} -{pragma_line} +#pragma once + #include #include #include @@ -369,22 +410,66 @@ class GemmKernelBuilder: #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +""" + return instance_code + def populate_kernel_dtype_layout(self): + # Determine accumulator type based on datatype + acc_type = "float" + + # Determine output type + c_type = self.datatype + if self.datatype in ["fp8", "bf8"]: + c_type = "fp16" + + # Assign layouts based on self.layout + if self.kernel_name_prefix == "gemm_multi_d": + a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout) + elif ( + self.kernel_name_prefix == "gemm_universal" + or self.kernel_name_prefix == "gemm_preshuffle" + ): + a_layout, b_layout, c_layout = get_abc_layouts(self.layout) + + instance_code = f""" using ADataType = {get_dtype_string(self.datatype)}; using BDataType = {get_dtype_string(self.datatype)}; using AccDataType = {acc_type}; -using CDataType = {get_dtype_string(c_type)}; +using CDataType = {get_dtype_string(c_type)};""" + if self.kernel_name_prefix == "gemm_multi_d": + instance_code += f""" +using D0DataType = {get_dtype_string(self.datatype)}; +using D1DataType = {get_dtype_string(self.datatype)}; +using DsDataType = ck_tile::tuple;""" + + instance_code += f""" using ALayout = {a_layout}; using BLayout = {b_layout}; using CLayout = {c_layout}; +""" + if self.kernel_name_prefix == "gemm_multi_d": + instance_code += f""" +using D0Layout = {ds_layout[0]}; +using D1Layout = {ds_layout[1]}; +using DsLayout = ck_tile::tuple; +using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};""" + + return instance_code + + def populate_strut_begin(self, kernel_name): + instance_code = f""" // Kernel name 
for display constexpr const char* KERNEL_NAME = "{kernel_name}"; // Wrapper for simplified launch interface struct SelectedKernel {{ - // Tile configuration + """ + return instance_code + + def populate_tile_config(self, tile_config): + instance_code = f"""// Tile configuration static constexpr ck_tile::index_t BlockSize = 256; static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; @@ -394,34 +479,187 @@ struct SelectedKernel {{ static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; - static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};""" + return instance_code - // Traits + def populate_trait_config(self, trait_combo): + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + instance_code = f""" + + // Traits configurations static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"}; static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; static constexpr bool TransposeC = false; + static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};""" + + if self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: + instance_code += f""" static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"}; - static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; static constexpr bool UseStructuredSparsity = false; - static constexpr bool Preshuffle = false; - static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr 
ck_tile::index_t NumWaveGroups = 1;""" + + if self.kernel_name_prefix == "gemm_preshuffle": + instance_code += f""" + static constexpr bool Preshuffle = true; + static constexpr bool PermuteN = {"true" if self.config.get("permute_n") else "false"};""" + else: + instance_code += """ + static constexpr bool Preshuffle = false;""" + return instance_code + + def populate_initialization(self, base_pipeline_map, pipeline): + # Tile Shape + if self.kernel_name_prefix == "gemm_multi_d": + instance_code = """ // Tile shape using TileShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, - ck_tile::sequence, - false, false>; - - // Tile partitioner - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + ck_tile::sequence>;""" - static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - - const auto Run = [&](const auto memory_operation_) {{ - constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; - [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; + elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: + instance_code = """ + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + false, false>;""" + + # Tile partitioner + instance_code += """ + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner;""" + + # Traits + if self.kernel_name_prefix == "gemm_multi_d": + instance_code += """ + + // Traits + using Traits = ck_tile::TileGemmTraits;""" + elif self.kernel_name_prefix == "gemm_preshuffle": + instance_code += """ + + // Traits + using Traits = ck_tile::TileGemmTraits;""" + + # Pipeline problem + if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]: + instance_code += """ + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>;""" + + # Base 
pipeline for hot loop detection + if self.kernel_name_prefix == "gemm_preshuffle": + instance_code += f""" + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")};""" + + elif self.kernel_name_prefix == "gemm_multi_d": + instance_code += f""" + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline)};""" + + return instance_code + + def populate_launch( + self, + scheduler_type_map, + scheduler, + pipeline_impl_map, + pipeline, + epilogue, + k_block_per_cu, + persistent, + ): + # Function Signature + if self.kernel_name_prefix == "gemm_multi_d": + instance_code = """ + + // Launch function + static float launch(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {""" + elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: + instance_code = """ + + // Launch function + static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {""" + + # Scheduler initialization + if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]: + instance_code += f""" + + constexpr auto scheduler = {scheduler_type_map.get(scheduler)};""" + + # Problem Initialization + if self.kernel_name_prefix == "gemm_preshuffle": + instance_code += """ + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler>;""" + elif self.kernel_name_prefix == "gemm_multi_d": + instance_code += """ + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler>;""" + + # GemmPipeline + if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]: + instance_code += f""" + + using GemmPipeline = {pipeline_impl_map.get(pipeline)};""" + + # Runfunction body 
+ instance_code += """ + + const auto Run = [&](const auto memory_operation_) {""" + + # Scheduler initialization + if self.kernel_name_prefix in ["gemm_universal"]: + instance_code += f""" + constexpr auto scheduler = {scheduler_type_map.get(scheduler)};""" + + # Memory operation + instance_code += """ + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;""" + + # UniversalGemmProblem + if self.kernel_name_prefix in ["gemm_universal"]: + instance_code += """ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< ADataType, @@ -432,16 +670,130 @@ struct SelectedKernel {{ ALayout, BLayout, CLayout, TransposeC, UseStructuredSparsity, UsePersistentKernel, NumWaveGroups, Preshuffle>, - scheduler>; - - using GemmPipeline = {pipeline_impl_map.get(pipeline)}; - - // Epilogue -""" + scheduler>;""" + + # GemmPipeline + if self.kernel_name_prefix in ["gemm_universal"]: + instance_code += f""" + + using GemmPipeline = {pipeline_impl_map.get(pipeline)};""" + + # Epilogue + instance_code += self.populate_epilogue(epilogue) + + # Kernel type + if self.kernel_name_prefix == "gemm_multi_d": + instance_code += """ + + // Kernel type + using GemmKernelMultiD = ck_tile::GemmKernelMultiD; + + // Kernel arguments + auto kargs = GemmKernelMultiD::MakeKernelArgs(args); + + if (!GemmKernelMultiD::IsSupportedArgument(kargs)) { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); + } + + // Get grid and block sizes + const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernelMultiD::BlockSize(); + + if(stream.log_level_ > 0) { + std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + }""" + + instance_code += f""" + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }};""" + + elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: + instance_code += f""" + + // Kernel type + using GemmKernel = ck_tile::GemmKernel; + + // Kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }}""" + + instance_code += f""" + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }};""" + + # Run SplitK handler + + instance_code += """ + + float ave_time = 0.f; + if(args.k_batch == 1) { + ave_time = Run(ck_tile::integral_constant{}); + } else { + ave_time = Run(ck_tile::integral_constant{}); + } + return ave_time; + } +}; +""" + return instance_code + + def populate_epilogue(self, epilogue): + instance_code = """ + + // Epilogue + """ - # Add epilogue configuration based on type if epilogue == "cshuffle": - instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + if self.kernel_name_prefix == "gemm_universal": + instance_code += self.populate_cshuffle_gemm_universal() + elif self.kernel_name_prefix == "gemm_multi_d": + instance_code += self.populate_cshuffle_gemm_multi_d() + elif self.kernel_name_prefix == "gemm_preshuffle": + instance_code += self.populate_cshuffle_gemm_preshuffle() + else: # default epilogue + if self.kernel_name_prefix == "gemm_universal": + instance_code += self.populate_default_gemm_universal() + elif self.kernel_name_prefix == "gemm_multi_d": + instance_code += self.populate_default_gemm_multi_d() + elif self.kernel_name_prefix == "gemm_preshuffle": + instance_code += self.populate_default_gemm_preshuffle() + + return 
instance_code + + def populate_cshuffle_gemm_universal(self): + instance_code = """ + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< ADataType, BDataType, ck_tile::tuple<>, // DsDataType @@ -461,10 +813,64 @@ struct SelectedKernel {{ memory_operation, // MemoryOperation_ NumWaveGroups>; // kNumWaveGroups_ - using GemmEpilogue = ck_tile::CShuffleEpilogue; -""" - else: # default epilogue - instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + return instance_code + + def populate_cshuffle_gemm_multi_d(self): + instance_code = """ + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation>; // MemoryOperation_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + return instance_code + + def populate_cshuffle_gemm_preshuffle(self): + instance_code = """ + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups, // kNumWaveGroups_ + false, // FixedVectorSize_ + 1, // VectorSizeC_ + PermuteN>; // isPermuteN_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + return instance_code + + def populate_default_gemm_universal(self): + instance_code = """ + using 
EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< ADataType, BDataType, ck_tile::tuple<>, // DsDataType @@ -482,151 +888,60 @@ struct SelectedKernel {{ WarpTileK, // kKPerXdl_ TransposeC>; // isCTransposed_ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; -""" + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + return instance_code - instance_code += f""" + def populate_default_gemm_multi_d(self): + instance_code = """ + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ - // Kernel type - using GemmKernel = ck_tile::GemmKernel; + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + return instance_code + + def populate_default_gemm_preshuffle(self): + instance_code = """ + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ - // Make kernel arguments - auto kargs = GemmKernel::MakeKernelArgs(args); - - if (!GemmKernel::IsSupportedArgument(kargs)) {{ - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); - }} - - // Get grid and block sizes - const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; - const dim3 blocks = GemmKernel::BlockSize(); - - if(stream.log_level_ > 0) {{ - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' - << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }} - - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - float ave_time = ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - - return ave_time; - }}; - - float ave_time = 0.f; - - if(args.k_batch == 1) {{ - ave_time = Run(ck_tile::integral_constant{{}}); - }} else {{ - ave_time = Run(ck_tile::integral_constant{{}}); - }} - - return ave_time; - }} -}}; -""" - return kernel_name, instance_code - - def run(self, num_workers=None): - """Run the builder to generate individual kernel files""" - # Generate individual kernel files - self.generate_individual(num_workers) - - def generate_individual(self, num_workers=None): - """Generate individual kernel files for separate compilation with parallel processing""" - if num_workers is None: - num_workers = min( - multiprocessing.cpu_count(), 8 - ) # Limit to avoid memory issues - - tile_configs = self._get_tile_configs() - trait_combos = self._generate_trait_combinations() - k_block_per_cu = self.config.get("k_block_per_cu") - if k_block_per_cu is None: - k_block_per_cu = 1 - - # Prepare work items for parallel processing - work_items = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - work_items.append( - ( - tile_config, - trait_combo, - k_block_per_cu, - self.working_path, - self.gpu_target, - self.datatype, - self.layout, - self.config_json, - ) - ) - print( - f"Generating {len(work_items)} 
individual kernel files using {num_workers} workers..." - ) - print(f" Tile configs: {len(tile_configs)}") - print(f" Trait combinations: {len(trait_combos)}") - print(f" Total kernels: {len(work_items)}") - - # Show first few work items for debugging - if work_items: - print(" First work item example:") - tile_config, trait_combo = work_items[0][:2] - print(f" Tile config: {tile_config}") - print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits - - # Process work items in parallel - kernel_list = [] - completed = 0 - - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_workers - ) as executor: - # Submit all work items - print(f" Submitting {len(work_items)} tasks to executor...") - future_to_item = { - executor.submit(_generate_single_kernel_individual, item): item - for item in work_items - } - print(" All tasks submitted, waiting for completion...") - - # Collect results with progress reporting - for future in concurrent.futures.as_completed(future_to_item): - completed += 1 - if completed % 100 == 0 or completed == len(work_items): - print( - f" Progress: {completed}/{len(work_items)} kernels generated" - ) - try: - result = future.result() - if result: - kernel_list.append(result) - except Exception as exc: - item = future_to_item[future] - print(f"Kernel generation failed for {item}: {exc}") - - # Sort kernel list for consistent ordering - kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name - - # Generate CMake include file for individual targets - self._generate_cmake_individual_targets(kernel_list) - - print( - f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" - ) + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + return instance_code def _generate_cmake_individual_targets(self, kernel_list): """Generate CMake include file that creates individual targets""" - cmake_code = f"""# Generated CMake file for individual GEMM targets -# Datatype: {self.datatype}, Layout: {self.layout} - -""" + 
cmake_code = f"""# Generated CMake file for individual {self.kernel_name_prefix} targets + # Datatype: {self.datatype}, Layout: {self.layout} + """ for kernel_name, trait_combo, tile_config in kernel_list: pipeline, epilogue, scheduler = trait_combo[:3] @@ -640,187 +955,11 @@ struct SelectedKernel {{ str(x) for x in trait_combo[3:] ) - cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + cmake_code += f'create_individual_{self.kernel_name_prefix}_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' # Write CMake include file - with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: + with open( + self.working_path / f"{self.kernel_name_prefix}_individual_targets.cmake", + "w", + ) as f: f.write(cmake_code) - - -def _generate_single_kernel_individual(work_item): - """Worker function to generate a single individual kernel file""" - ( - tile_config, - trait_combo, - k_block_per_cu, - working_path, - gpu_target, - datatype, - layout, - config_json, - ) = work_item - - # Create a temporary builder instance for this worker - builder = GemmKernelBuilder(working_path, gpu_target, datatype, layout, config_json) - - try: - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu - ) - - # Create simplified filename without the "gemm_" prefix - # Remove "gemm_" from the beginning of kernel_name for the filename - simplified_name = kernel_name - if simplified_name.startswith("gemm_"): - simplified_name = simplified_name[5:] # Remove "gemm_" prefix - - # Write individual header file - header_file = working_path / f"gemm_single_{simplified_name}.hpp" - with open(header_file, "w") as f: - f.write(instance_code) - - return (kernel_name, trait_combo, tile_config) - except Exception as e: - print(f"Error generating individual kernel: {e}") - return None - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM kernel 
instance builder with parallel support" - ) - parser.add_argument("--working_path", required=True, help="Working directory path") - parser.add_argument( - "--gpu_target", - required=True, - help="GPU target architecture", - ) - parser.add_argument( - "--datatype", - required=True, - choices=["fp16", "fp8", "bf16", "bf8"], - help="Data type", - ) - parser.add_argument( - "--layout", - required=True, - choices=["rcr", "rrr", "ccr", "crr"], - help="Matrix layout", - ) - parser.add_argument("--config_json", help="Configuration JSON file") - parser.add_argument( - "--num_workers", type=int, help="Number of parallel workers (default: auto)" - ) - parser.add_argument( - "--gen_all_individual", - action="store_true", - help="Generate individual kernel files", - ) - parser.add_argument( - "--gen_single", action="store_true", help="Generate a single kernel file" - ) - parser.add_argument("--kernel_name", help="Kernel name for single generation") - parser.add_argument( - "--tile_config", help="Tile configuration string for single generation" - ) - parser.add_argument( - "--trait_combo", help="Trait combination string for single generation" - ) - parser.add_argument( - "--list_kernels", - action="store_true", - help="List kernel configurations without generating files", - ) - - args = parser.parse_args() - - assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( - f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" - ) - - layout_parts = args.layout.lower() - assert len(layout_parts) == 3, ( - f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" - ) - assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( - f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" - ) - assert layout_parts[2] == "r", ( - f"Invalid matrix_c 
layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" - ) - - # Create builder - builder = GemmKernelBuilder( - args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json - ) - - if args.list_kernels: - builder.write_kernel_list() - elif args.gen_single: - # Generate a single kernel file - if not args.kernel_name or not args.tile_config or not args.trait_combo: - parser.error( - "--gen_single requires --kernel_name, --tile_config, and --trait_combo" - ) - - # Parse tile config - tile_parts = args.tile_config.split("_") - tile_dims = tile_parts[0].split("x") - warp_dims = tile_parts[1].split("x") - warp_tile_dims = tile_parts[2].split("x") - - tile_config = { - "tile_m": int(tile_dims[0]), - "tile_n": int(tile_dims[1]), - "tile_k": int(tile_dims[2]), - "warp_m": int(warp_dims[0]), - "warp_n": int(warp_dims[1]), - "warp_k": int(warp_dims[2]), - "warp_tile_m": int(warp_tile_dims[0]), - "warp_tile_n": int(warp_tile_dims[1]), - "warp_tile_k": int(warp_tile_dims[2]), - } - - # Parse trait combo - trait_parts = args.trait_combo.split("_") - trait_combo = ( - trait_parts[0], # pipeline - trait_parts[1], # epilogue - trait_parts[2], # scheduler - trait_parts[3] == "True", # pad_m - trait_parts[4] == "True", # pad_n - trait_parts[5] == "True", # pad_k - trait_parts[6] == "True", # persistent - ) - - k_block_per_cu = builder.config.get("k_block_per_cu") - if k_block_per_cu is None: - k_block_per_cu = 1 - - # Generate the kernel - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu - ) - - # Write the file - simplified_name = kernel_name - if simplified_name.startswith("gemm_"): - simplified_name = simplified_name[5:] - - header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp" - with open(header_file, "w") as f: - f.write(instance_code) - - print(f"Generated {header_file}") - - elif args.gen_all_individual: - # Generate all individual kernel 
files - builder.run(args.num_workers) - else: - parser.error( - "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" - ) - - -if __name__ == "__main__": - main() diff --git a/tile_engine/ops/gemm_multi_d/CMakeLists.txt b/tile_engine/ops/gemm/gemm_multi_d/CMakeLists.txt similarity index 99% rename from tile_engine/ops/gemm_multi_d/CMakeLists.txt rename to tile_engine/ops/gemm/gemm_multi_d/CMakeLists.txt index 43164cd73c..b5f9a4b177 100644 --- a/tile_engine/ops/gemm_multi_d/CMakeLists.txt +++ b/tile_engine/ops/gemm/gemm_multi_d/CMakeLists.txt @@ -70,7 +70,6 @@ function(create_individual_gemm_multi_d_target datatype layout trait tile_config # Create the executable add_executable(${target_name} - # to save build time, exclude the target from "all" target of "gemm_multi_d" directory and its ancestors EXCLUDE_FROM_ALL ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp ${instance_header} diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/gemm_multi_d/configs/default_config.json similarity index 100% rename from tile_engine/ops/gemm/configs/default_config.json rename to tile_engine/ops/gemm/gemm_multi_d/configs/default_config.json diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/gemm_multi_d/configs/user_provided_config.json similarity index 100% rename from tile_engine/ops/gemm/configs/user_provided_config.json rename to tile_engine/ops/gemm/gemm_multi_d/configs/user_provided_config.json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp similarity index 100% rename from tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp rename to tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py old mode 100755 new mode 100644 similarity index 99% rename 
from tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py rename to tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py index 044e08baca..faf04a7de0 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark_single.cpp similarity index 94% rename from tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp rename to tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 25ac342f3e..41d2f736e1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = ck_tile::DataTypeTraits::name; - std::string dtype_b = ck_tile::DataTypeTraits::name; - std::string dtype_acc = ck_tile::DataTypeTraits::name; - std::string dtype_c = ck_tile::DataTypeTraits::name; - std::string dtype_d0 = ck_tile::DataTypeTraits::name; - std::string dtype_d1 = ck_tile::DataTypeTraits::name; + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + std::string dtype_d0 = DataTypeTraits::name; + std::string dtype_d1 = DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_common.hpp new file mode 100644 
index 0000000000..899221547f --- /dev/null +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_common.hpp @@ -0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +//[TODO] This can be moved to commons +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; diff --git 
a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_instance_builder.py new file mode 100644 index 0000000000..1be8880bf0 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -0,0 +1,330 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import os +import argparse +import importlib.util +import multiprocessing +import concurrent.futures + + +def _import_gemm_kernel_builder(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "gemm_instance_builder", + os.path.join(parent_dir, "gemm_instance_builder.py"), + ) + gemm_builder_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gemm_builder_module) + + return gemm_builder_module.GemmKernelBuilder + + +GemmKernelBuilder = _import_gemm_kernel_builder() + + +class GemmMultiDKernelBuilder(GemmKernelBuilder): + def __init__( + self, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json=None, + ): + super().__init__( + kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json + ) + self.elementwise_function = elementwise_function + + def _generate_all_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + 
self.kernel_name_prefix, + self.working_path, + self.gpu_target, + self.datatype, + self.layout, + self.elementwise_function, + self.config_json, + ) + ) + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + ( + tile_config, 
+ trait_combo, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json, + ) = work_item + + # Create a temporary builder instance for this worker + builder = GemmMultiDKernelBuilder( + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + elementwise_function, + config_json, + ) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_multi_d_" prefix + # Remove "gemm_multi_d_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_multi_d_"): + simplified_name = simplified_name[ + len(kernel_name_prefix) + 1 : + ] # Remove "gemm_multi_d_" prefix + + # Write individual header file + header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Multi D kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument("--gpu_target", required=True, help="GPU target architecture") + parser.add_argument( + "--datatype", + required=True, + choices=["fp16"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcrr", "rrrr", "ccrr", "crrr"], + help="Matrix layout", + ) + parser.add_argument( + "--elementwise_function", + required=True, + help="Specify what element wise function for D, e.g. 
mul, add, passthrough", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + assert args.datatype in ["fp16"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 4, ( + f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" + ) + assert layout_parts[2] == "r" and layout_parts[3] == "r", ( + f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)" + ) + + # Elementwise function name validation + elementwise_function = args.elementwise_function.lower() + + valid_functions = ["mul", "add", "passthrough"] + if elementwise_function not in valid_functions: + raise ValueError( + f"Invalid elementwise function: {elementwise_function}. 
" + f"Valid options are: {', '.join(valid_functions)}" + ) + + # Set the function name based on the elementwise function + if elementwise_function == "mul": + function_name = "MultiDMultiply" + elif elementwise_function == "add": + function_name = "MultiDAdd" + elif elementwise_function == "passthrough": + function_name = "PassThrough" + + args.elementwise_function = function_name + + # Create builder + kernel_name_prefix = "gemm_multi_d" + builder = GemmMultiDKernelBuilder( + kernel_name_prefix, + args.working_path, + args.gpu_target, + args.datatype, + args.layout, + args.elementwise_function, + args.config_json, + ) + + if args.list_kernels: + builder._list_kernels() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3], # pad_m + trait_parts[4], # pad_n + trait_parts[5], # pad_k + trait_parts[6], # persistent + ) + + # Generate the kernel + builder._generate_kernel_instance( + tile_config, + trait_combo, + ) + elif args.gen_all_individual: + # Generate all individual kernel files + builder._generate_all_individual(args.num_workers) + else: + parser.error( + "Must specify one of: 
--list_kernels, --gen_all_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_profiler.hpp similarity index 100% rename from tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp rename to tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_profiler.hpp diff --git a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt b/tile_engine/ops/gemm/gemm_preshuffle/CMakeLists.txt similarity index 99% rename from tile_engine/ops/gemm_preshuffle/CMakeLists.txt rename to tile_engine/ops/gemm/gemm_preshuffle/CMakeLists.txt index c89fe236dd..ad93007fe3 100644 --- a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt +++ b/tile_engine/ops/gemm/gemm_preshuffle/CMakeLists.txt @@ -67,7 +67,6 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con # Create the executable add_executable(${target_name} - # to save build time, exclude the target from "all" target of "gemm_preshuffle" directory and its ancestors EXCLUDE_FROM_ALL ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp ${instance_header} diff --git a/tile_engine/ops/gemm_preshuffle/configs/default_config.json b/tile_engine/ops/gemm/gemm_preshuffle/configs/default_config.json similarity index 100% rename from tile_engine/ops/gemm_preshuffle/configs/default_config.json rename to tile_engine/ops/gemm/gemm_preshuffle/configs/default_config.json diff --git a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json similarity index 100% rename from tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json rename to tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp similarity index 100% rename from 
tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp rename to tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py old mode 100755 new mode 100644 similarity index 99% rename from tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py rename to tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py index d8892be7d6..53ae6336fa --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp similarity index 94% rename from tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp rename to tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 0d5de02750..4fbb25f0c9 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -#include "ck_tile/ops/common/utils.hpp" #include "gemm_preshuffle_profiler.hpp" #include "gemm_preshuffle_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use ck_tile::DataTypeTraits to get the actual type names from the generated header + 
// Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = ck_tile::DataTypeTraits::name; - std::string dtype_b = ck_tile::DataTypeTraits::name; - std::string dtype_acc = ck_tile::DataTypeTraits::name; - std::string dtype_c = ck_tile::DataTypeTraits::name; + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_common.hpp new file mode 100644 index 0000000000..1b2cfe3735 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -0,0 +1,181 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +//[TODO] This can be moved to commons +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits 
+{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // preshufflev2 + std::string scheduler; // intrawave, interwave, default + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("preshufflev2"), + scheduler("default"), + epilogue("default"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("preshufflev2") != std::string::npos) + { + traits.pipeline = "preshufflev2"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else if(kernel_name.find("intrawave") != std::string::npos) + { + traits.scheduler = "intrawave"; + } + else + { + traits.scheduler = "default"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} + +template +auto shuffle_b(const ck_tile::HostTensor& t, + ck_tile::index_t N_Warp_Tile, + ck_tile::index_t K_Warp_Tile) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int divisor = N_Warp_Tile == 32 ? 
2 : 4; + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, + ck_tile::index_t N_Warp_Tile, + ck_tile::index_t K_Warp_Tile, + ck_tile::index_t N_Tile, + ck_tile::index_t N_Warp) +{ + assert(t.get_lengths().size() == 2); + + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int divisor = N_Warp_Tile == 32 ? 2 : 4; + int NRepeat = N_Tile / N_Warp_Tile / N_Warp; + ck_tile::HostTensor t_view({n_ / N_Tile, + N_Warp, + N_Warp_Tile, + NRepeat, + k_ / K_Warp_Tile, + divisor, + K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} diff --git a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_instance_builder.py new file mode 100644 index 0000000000..6053be4d06 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -0,0 +1,300 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +import os +import argparse +import importlib.util +import multiprocessing +import concurrent.futures + + +def _import_gemm_kernel_builder(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "gemm_instance_builder", + os.path.join(parent_dir, "gemm_instance_builder.py"), + ) + gemm_builder_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gemm_builder_module) + + return gemm_builder_module.GemmKernelBuilder + + +GemmKernelBuilder = _import_gemm_kernel_builder() + + +class GemmPreshuffleKernelBuilder(GemmKernelBuilder): + def __init__( + self, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json=None, + ): + super().__init__( + kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json + ) + + def _generate_all_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.kernel_name_prefix, + self.working_path, + self.gpu_target, + self.datatype, + self.layout, + self.config_json, + ) + ) + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." 
+ ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + ( + tile_config, + trait_combo, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json, + ) = work_item + + # Create a temporary builder instance for this worker + builder = GemmPreshuffleKernelBuilder( + kernel_name_prefix, 
working_path, gpu_target, datatype, layout, config_json + ) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_preshuffle_" prefix + # Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_preshuffle_"): + simplified_name = simplified_name[ + len(kernel_name_prefix) + 1 : + ] # Remove "gemm_preshuffle_" prefix + + # Write individual header file + header_file = working_path / f"gemm_preshuffle_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument( + "--gpu_target", + required=True, + help="GPU target architecture", + ) + parser.add_argument( + "--datatype", + required=True, + choices=["fp16", "fp8", "bf16", "bf8"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", required=True, help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", 
help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + + # Create builder + kernel_name_prefix = "gemm_preshuffle" + builder = GemmPreshuffleKernelBuilder( + kernel_name_prefix, + args.working_path, + args.gpu_target, + args.datatype, + args.layout, + args.config_json, + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder._list_kernels() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": 
int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + builder._generate_kernel_instance( + tile_config, + trait_combo, + ) + + elif args.gen_all_individual: + # Generate all individual kernel files + builder._generate_all_individual(args.num_workers) + pass + else: + parser.error( + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_profiler.hpp similarity index 94% rename from tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp rename to tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_profiler.hpp index cad53b472f..739bd7e677 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -111,30 +111,21 @@ class GemmProfiler c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - struct GemmConfig - { - ck_tile::index_t N_Warp_Tile; - ck_tile::index_t K_Warp_Tile; - ck_tile::index_t N_Tile; - ck_tile::index_t N_Warp; - }; - for(const auto& callable : callables) { - GemmConfig gemmConfig = {}; - gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims); - gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims); - gemmConfig.N_Tile = std::get<1>(config.tile_dims); - gemmConfig.N_Warp = std::get<1>(config.warp_dims); + ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); + ck_tile::index_t K_Warp_Tile = 
std::get<2>(config.warp_tile_dims); + ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); + ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); ck_tile::HostTensor b_shuffle_host = [&]() { if(config.permuteN) { - return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig); + return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); } else { - return ck_tile::shuffle_b(b_k_n, gemmConfig); + return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); } }(); diff --git a/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt b/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt new file mode 100644 index 0000000000..7505fcd6d0 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_universal/CMakeLists.txt @@ -0,0 +1,309 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)") +set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)") +set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM Universal targets +function(create_individual_gemm_universal_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") + return() + endif() + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE 
"_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + --gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}" + DEPENDS ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + EXCLUDE_FROM_ALL + ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_benchmark_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET 
${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_UNIVERSAL_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_universal_all ${target_name}) + add_dependencies(benchmark_gemm_universal_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_universal_${layout} ${target_name}) + add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name}) + add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name}) + add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name}) +endfunction() + +# Function to build individual GEMM Universal targets +function(build_individual_gemm_universal_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_UNIVERSAL_CONFIG_FILE + # 2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE + # 3. 
Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(VERBOSE " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}") + message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(VERBOSE " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(VERBOSE " Working path: ${working_path}") + message(VERBOSE " Config file: ${json_blob}") + message(VERBOSE " Python executable: ${Python3_EXECUTABLE}") + message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py + --working_path 
${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} + --list_kernels ") + + # First, just list the kernels (fast operation) + message(VERBOSE " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL} + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() + + # Read kernel count + if(EXISTS ${working_path}/gemm_universal_kernel_count.txt) + file(READ ${working_path}/gemm_universal_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(VERBOSE " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_universal_kernel_list.txt) + file(STRINGS ${working_path}/gemm_universal_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") + endforeach() + else() + message(FATAL_ERROR "Kernel list file not found") + endif() +endfunction() + +# Main build logic - Only individual builds supported +message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===") +message(VERBOSE 
"GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}") +message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}") +message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 +set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${target}) + message(VERBOSE " Adding GPU target: ${target}") + endif() +endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM_UNIVERSAL) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(VERBOSE "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_universal_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + 
add_custom_target(benchmark_gemm_universal_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + add_custom_target(benchmark_gemm_universal_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + add_custom_target(benchmark_gemm_universal_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM Universal kernels + set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4") + set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle") + set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_UNIVERSAL_PIPELINES) + add_custom_target(benchmark_gemm_universal_${pipeline}_pipeline) + endforeach() + + foreach(epilogue IN LISTS GEMM_UNIVERSAL_EPILOGUES) + add_custom_target(benchmark_gemm_universal_${epilogue}_epilogue) + endforeach() + + foreach(scheduler IN LISTS GEMM_UNIVERSAL_SCHEDULERS) + add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE) + foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT) + build_individual_gemm_universal_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm_multi_d/configs/default_config.json b/tile_engine/ops/gemm/gemm_universal/configs/default_config.json similarity index 100% rename from tile_engine/ops/gemm_multi_d/configs/default_config.json rename to tile_engine/ops/gemm/gemm_universal/configs/default_config.json diff --git a/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json b/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json similarity index 87% rename from tile_engine/ops/gemm_multi_d/configs/user_provided_config.json rename to 
tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json index 40a7dda6cc..ddf30bb69b 100644 --- a/tile_engine/ops/gemm_multi_d/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/gemm_universal/configs/user_provided_config.json @@ -2,12 +2,12 @@ "tile_config": { "tile_m": { "values": [ - 64 + 128 ] }, "tile_n": { "values": [ - 192 + 128 ] }, "tile_k": { @@ -17,12 +17,12 @@ }, "warp_m": { "values": [ - 2 + 4 ] }, "warp_n": { "values": [ - 2 + 1 ] }, "warp_k": { @@ -32,24 +32,24 @@ }, "warp_tile_m": { "values": [ - 32 + 16 ] }, "warp_tile_n": { "values": [ - 32 + 16 ] }, "warp_tile_k": { "values": [ - 8 + 16 ] } }, "trait_config": { "pipeline": { "values": [ - "compv4" + "mem" ] }, "scheduler": { @@ -59,7 +59,7 @@ }, "epilogue": { "values": [ - "cshuffle" + "default" ] }, "pad_m": { @@ -83,5 +83,5 @@ ] } }, - "k_block_per_cu": 1 + "k_block_per_cu": 2 } \ No newline at end of file diff --git a/tile_engine/ops/gemm/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp similarity index 100% rename from tile_engine/ops/gemm/gemm_benchmark.hpp rename to tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py old mode 100755 new mode 100644 similarity index 99% rename from tile_engine/ops/gemm/gemm_benchmark.py rename to tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py index cc04dbe0db..b7424c6d1d --- a/tile_engine/ops/gemm/gemm_benchmark.py +++ b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT diff --git a/tile_engine/ops/gemm/gemm_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark_single.cpp similarity index 93% rename from tile_engine/ops/gemm/gemm_benchmark_single.cpp rename to tile_engine/ops/gemm/gemm_universal/gemm_benchmark_single.cpp index 26f3a3928a..6323c066a1 100644 --- a/tile_engine/ops/gemm/gemm_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -#include "ck_tile/ops/common/utils.hpp" #include "gemm_profiler.hpp" #include "gemm_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use ck_tile::DataTypeTraits to get the actual type names from the generated header + // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = ck_tile::DataTypeTraits::name; - std::string dtype_b = ck_tile::DataTypeTraits::name; - std::string dtype_acc = ck_tile::DataTypeTraits::name; - std::string dtype_c = ck_tile::DataTypeTraits::name; + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp new file mode 100644 index 0000000000..899221547f --- /dev/null +++ b/tile_engine/ops/gemm/gemm_universal/gemm_common.hpp @@ 
-0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +//[TODO] This can be moved to commons +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_profiler.hpp similarity index 100% 
rename from tile_engine/ops/gemm/gemm_profiler.hpp rename to tile_engine/ops/gemm/gemm_universal/gemm_profiler.hpp diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py b/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py new file mode 100644 index 0000000000..08f0e7e284 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_universal/gemm_universal_instance_builder.py @@ -0,0 +1,295 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import os +import argparse +import importlib.util +import multiprocessing +import concurrent.futures + + +def _import_gemm_kernel_builder(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "gemm_instance_builder", + os.path.join(parent_dir, "gemm_instance_builder.py"), + ) + gemm_builder_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gemm_builder_module) + + return gemm_builder_module.GemmKernelBuilder + + +GemmKernelBuilder = _import_gemm_kernel_builder() + + +class GemmUniversalKernelBuilder(GemmKernelBuilder): + def __init__( + self, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json=None, + ): + super().__init__( + kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json + ) + + def _generate_all_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: 
+ work_items.append( + ( + tile_config, + trait_combo, + self.kernel_name_prefix, + self.working_path, + self.gpu_target, + self.datatype, + self.layout, + self.config_json, + ) + ) + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel 
file""" + ( + tile_config, + trait_combo, + kernel_name_prefix, + working_path, + gpu_target, + datatype, + layout, + config_json, + ) = work_item + + # Create a temporary builder instance for this worker + builder = GemmUniversalKernelBuilder( + kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json + ) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_universal_" prefix + # Remove "gemm_universal_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_universal_"): + simplified_name = simplified_name[ + len(kernel_name_prefix) + 1 : + ] # Remove "gemm_universal" prefix + + # Write individual header file + header_file = working_path / f"gemm_universal_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Universal kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument( + "--gpu_target", + required=True, + help="GPU target architecture", + ) + parser.add_argument( + "--datatype", + required=True, + choices=["fp16", "fp8", "bf16", "bf8"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", + ) + parser.add_argument( + "--gen_single", 
action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], ( + f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])" + ) + + layout_parts = args.layout.lower() + assert len(layout_parts) == 3, ( + f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" + ) + assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( + f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" + ) + assert layout_parts[2] == "r", ( + f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)" + ) + + kernel_name_prefix = "gemm_universal" + builder = GemmUniversalKernelBuilder( + kernel_name_prefix, + args.working_path, + args.gpu_target, + args.datatype, + args.layout, + args.config_json, + ) + + if args.list_kernels: + builder._list_kernels() + elif args.gen_single: + # Generate a single kernel file input validation + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + 
"tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + builder._generate_kernel_instance( + tile_config, + trait_combo, + ) + elif args.gen_all_individual: + # Generate all individual kernel files + builder._generate_all_individual(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/commons/gemm_validation_utils.py b/tile_engine/ops/gemm/gemm_validation_utils.py similarity index 99% rename from tile_engine/ops/commons/gemm_validation_utils.py rename to tile_engine/ops/gemm/gemm_validation_utils.py index 37a944aef7..cae6123307 100644 --- a/tile_engine/ops/commons/gemm_validation_utils.py +++ b/tile_engine/ops/gemm/gemm_validation_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT @@ -660,7 +659,6 @@ def validate_whole_wg_cover_configuration( ) if not wg_cover_core_valid: - print("I am here 3") logging.debug( f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}" ) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp deleted file mode 100644 index 1fdc63b33b..0000000000 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" - -// Helper function to determine if a layout is row-major -template -constexpr auto is_row_major(Layout) -{ - return ck_tile::bool_constant>{}; -} - -// Structure to hold kernel traits for dispatcher -struct KernelTraits -{ - std::string pipeline; // compv3, compv4, mem - std::string scheduler; // intrawave, interwave - std::string epilogue; // cshuffle, default - bool pad_m; - bool pad_n; - bool pad_k; - bool persistent; - - // Constructor with defaults - KernelTraits() - : pipeline("compv3"), - scheduler("intrawave"), - epilogue("cshuffle"), - pad_m(false), - pad_n(false), - pad_k(false), - persistent(false) - { - } -}; diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py deleted file mode 100644 index f04c2a2c96..0000000000 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ /dev/null @@ -1,891 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-# SPDX-License-Identifier: MIT - - -import os -import json -import argparse -import itertools -import multiprocessing -import concurrent.futures -from pathlib import Path -import logging -import importlib.util - - -def _import_validation_utils(): - """Import validation utilities from commons directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - parent_dir = os.path.dirname(current_dir) - - # Load the module dynamically - spec = importlib.util.spec_from_file_location( - "validation_utils", - os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), - ) - validation_utils = importlib.util.module_from_spec(spec) - spec.loader.exec_module(validation_utils) - - return validation_utils - - -# Import validation functions -_validation_utils = _import_validation_utils() -is_tile_config_valid = _validation_utils.is_tile_config_valid -is_trait_combination_valid = _validation_utils.is_trait_combination_valid -get_dtype_string = _validation_utils.get_dtype_string -get_abcd_layouts = _validation_utils.get_abcd_layouts - -logging.basicConfig(level=logging.INFO) - - -class GemmMultiDKernelBuilder: - def __init__( - self, - working_path, - gpu_target, - datatype, - layout, - elementwise_function, - config_json=None, - ): - self.working_path = Path(working_path) - self.gpu_target = gpu_target - self.datatype = datatype - self.layout = layout - self.elementwise_function = elementwise_function - self.config_json = config_json - - # Create working directory if it doesn't exist - self.working_path.mkdir(parents=True, exist_ok=True) - - # Load configuration - if config_json and os.path.exists(config_json): - with open(config_json, "r") as f: - self.config = json.load(f) - - def write_kernel_list(self): - """Write kernel list to file for CMake to read (with comprehensive validation)""" - # Get configurations using comprehensive validation - tile_configs = self._get_tile_configs(fast_mode=False) - trait_combos = self._generate_trait_combinations() - - 
kernel_list = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - ( - pipeline, - epilogue, - scheduler, - pad_m, - pad_n, - pad_k, - persistent, - ) = trait_combo - - # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - - # Create tile configuration string - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - kernel_name += f"_{tile_str}" - - kernel_list.append( - { - "name": kernel_name, - "tile_config": tile_config, - "trait_combo": trait_combo, - } - ) - - # Write kernel count - with open(self.working_path / "gemm_multi_d_kernel_count.txt", "w") as f: - f.write(str(len(kernel_list))) - - # Write kernel list - with open(self.working_path / "gemm_multi_d_kernel_list.txt", "w") as f: - for kernel in kernel_list: - # Format: kernel_name|tile_config|trait_combo - tile_config = kernel["tile_config"] - trait_combo = kernel["trait_combo"] - - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - trait_str = ( - f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" - + "_".join(str(x) for x in trait_combo[3:]) - ) - - f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") - - print(f"Listed {len(kernel_list)} kernel configurations") - - def _get_tile_configs(self, fast_mode=False): - """Get tile configurations for the current datatype and layout""" - tile_config = 
self.config["tile_config"] - - # Generate values in the config if default range is given - if tile_config.get("tile_m").get("values") is None: - tile_config.get("tile_m")["values"] = self._generate_values( - tile_config.get("tile_m").get("min"), - tile_config.get("tile_m").get("max"), - tile_config.get("tile_m").get("step"), - ) - if tile_config.get("tile_n").get("values") is None: - tile_config.get("tile_n")["values"] = self._generate_values( - tile_config.get("tile_n").get("min"), - tile_config.get("tile_n").get("max"), - tile_config.get("tile_n").get("step"), - ) - if tile_config.get("tile_k").get("values") is None: - tile_config.get("tile_k")["values"] = self._generate_values( - tile_config.get("tile_k").get("min"), - tile_config.get("tile_k").get("max"), - tile_config.get("tile_k").get("step"), - ) - - # Get all possible values for each parameter - tile_m_values = tile_config.get("tile_m").get("values") - tile_n_values = tile_config.get("tile_n").get("values") - tile_k_values = tile_config.get("tile_k").get("values") - warp_m_values = tile_config.get("warp_m").get("values") - warp_n_values = tile_config.get("warp_n").get("values") - warp_k_values = tile_config.get("warp_k").get("values") - warp_tile_m_values = tile_config.get("warp_tile_m").get("values") - warp_tile_n_values = tile_config.get("warp_tile_n").get("values") - warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - - # Generate all combinations - configs = [] - for tile_m in tile_m_values: - for tile_n in tile_n_values: - for tile_k in tile_k_values: - for warp_m in warp_m_values: - for warp_n in warp_n_values: - for warp_k in warp_k_values: - for warp_tile_m in warp_tile_m_values: - for warp_tile_n in warp_tile_n_values: - for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - fast_mode=fast_mode, - ): - configs.append( - { - "tile_m": 
tile_m, - "tile_n": tile_n, - "tile_k": tile_k, - "warp_m": warp_m, - "warp_n": warp_n, - "warp_k": warp_k, - "warp_tile_m": warp_tile_m, - "warp_tile_n": warp_tile_n, - "warp_tile_k": warp_tile_k, - } - ) - return configs - - def _generate_values(self, min_val, max_val, step): - """Generate a list of values from min to max with the given step""" - values = [] - val = min_val - while val <= max_val: - values.append(val) - val += step - return values - - def _validate_tile_config( - self, - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - pipeline="compv4", # Default pipeline for validation - fast_mode=False, # Add fast mode option - ): - """Validate that tile configuration is reasonable""" - if fast_mode: - # Fast validation for listing - only basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: - return False - - # Basic divisibility check - if tile_m % (warp_m * warp_tile_m) != 0: - return False - if tile_n % (warp_n * warp_tile_n) != 0: - return False - if tile_k % (warp_k * warp_tile_k) != 0: - return False - - return True - else: - # Full validation for generation - # Determine data types for validation - a_datatype = self.datatype - b_datatype = self.datatype - c_datatype = self.datatype - - layout = self.layout - - # Special handling for certain data types - if self.datatype in ["fp8", "bf8"]: - c_datatype = "fp16" - - # Use the comprehensive validation function - return is_tile_config_valid( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - pipeline, - layout, - self.gpu_target, - ) - - def _generate_trait_combinations(self): - """Generate all combinations of traits""" - - trait_config = self.config["trait_config"] - - pipelines = 
trait_config.get("pipeline").get("values") - epilogues = trait_config.get("epilogue").get("values") - schedulers = trait_config.get("scheduler").get("values") - pad_m_values = trait_config.get("pad_m").get("values") - pad_n_values = trait_config.get("pad_n").get("values") - pad_k_values = trait_config.get("pad_k").get("values") - persistent_values = trait_config.get("persistent").get("values") - - all_combinations = list( - itertools.product( - pipelines, - epilogues, - schedulers, - pad_m_values, - pad_n_values, - pad_k_values, - persistent_values, - ) - ) - - # Filter out unsupported trait combinations - combinations = [] - for combo in all_combinations: - pipeline, epilogue, scheduler = combo[:3] - if is_trait_combination_valid(pipeline, epilogue, scheduler): - combinations.append(combo) - else: - logging.debug( - f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" - ) - return combinations - - def _generate_kernel_instance( - self, tile_config, trait_combo, k_block_per_cu, is_header=True - ): - """Generate a single kernel instance""" - ( - pipeline, - epilogue, - scheduler, - pad_m, - pad_n, - pad_k, - persistent, - ) = trait_combo - - # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - - # Create tile configuration string - tile_str = ( - f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - ) - tile_str += ( - f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - ) - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - kernel_name += f"_{tile_str}" - - # Map pipeline names to the correct pipeline implementation - pipeline_impl_map = { - "mem": "ck_tile::GemmPipelineAgBgCrMem", - "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", - 
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4", - } - - # Map pipeline names to base pipeline for hot loop detection - base_pipeline_map = { - "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", - "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", - } - - # Map scheduler names to the correct enum values - scheduler_type_map = { - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "default": "ck_tile::GemmPipelineScheduler::Default", - } - - # Determine accumulator type based on datatype - acc_type = "float" - - # Determine output type - c_type = self.datatype - if self.datatype in ["fp8", "bf8"]: - c_type = "fp16" - - # Determine layouts based on self.layout - a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout) - - # Generate kernel instance code using the correct API - pragma_line = "#pragma once\n" if is_header else "" - instance_code = f"""// Generated kernel instance for {kernel_name} -{pragma_line} -#include -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" - -using ADataType = {get_dtype_string(self.datatype)}; -using BDataType = {get_dtype_string(self.datatype)}; -using AccDataType = {acc_type}; -using CDataType = {get_dtype_string(c_type)}; -using D0DataType = {get_dtype_string(self.datatype)}; -using D1DataType = {get_dtype_string(self.datatype)}; -using DsDataType = ck_tile::tuple; - -using ALayout = {a_layout}; -using BLayout = {b_layout}; -using CLayout = {c_layout}; -using D0Layout = {ds_layout[0]}; -using D1Layout = {ds_layout[1]}; -using DsLayout = ck_tile::tuple; - -using ElementWiseFn = 
ck_tile::element_wise::{self.elementwise_function}; - -// Kernel name for display -constexpr const char* KERNEL_NAME = "{kernel_name}"; - -// Wrapper for simplified launch interface -struct SelectedKernel {{ - // Tile configuration - static constexpr ck_tile::index_t BlockSize = 256; - static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; - static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; - static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; - static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; - static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; - static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; - static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; - static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; - static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; - - // Traits - static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"}; - static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; - static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; - - static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; - static constexpr bool TransposeC = false; - - // Tile shape - using TileShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - // Tile partitioner - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - - // Traits - using Traits = ck_tile::TileGemmTraits; - - // Pipeline problem - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - Traits>; - - // Base pipeline for hot loop detection - using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; - - static float launch(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& stream) {{ - 
constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - ck_tile::TileGemmUniversalTraits, - scheduler>; - - using GemmPipeline = {pipeline_impl_map.get(pipeline)}; - - const auto Run = [&](const auto memory_operation_) {{ - [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - - // Epilogue -""" - - # Add epilogue configuration based on type - if epilogue == "cshuffle": - instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - memory_operation>; // MemoryOperation_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue; -""" - else: # default epilogue - instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; -""" - - instance_code += f""" - - // Kernel type - using GemmKernelMultiD = ck_tile::GemmKernelMultiD; - - // Make kernel arguments - auto kargs = GemmKernelMultiD::MakeKernelArgs(args); - - if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {{ - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); - }} - - // Get grid and block sizes - const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernelMultiD::BlockSize(); - - if(stream.log_level_ > 0) {{ - std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' - << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }} - - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - return ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); - }}; - - if(args.k_batch == 1) {{ - return Run(ck_tile::integral_constant{{}}); - }} else {{ - return Run(ck_tile::integral_constant{{}}); - }} - }} -}}; -""" - return kernel_name, instance_code - - def run(self, num_workers=None): - """Run the builder to generate individual kernel files""" - # Generate individual kernel files - self.generate_individual(num_workers) - - def generate_individual(self, num_workers=None): - """Generate individual kernel files for separate compilation with parallel processing""" - if num_workers is None: - num_workers = min( - multiprocessing.cpu_count(), 8 - ) # Limit to avoid memory issues - - tile_configs = self._get_tile_configs() - trait_combos = self._generate_trait_combinations() - k_block_per_cu = self.config.get("k_block_per_cu") - if k_block_per_cu is None: - k_block_per_cu = 1 - - # Prepare work items for parallel processing - work_items = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - work_items.append( - ( - tile_config, - trait_combo, - k_block_per_cu, - self.working_path, - self.gpu_target, - self.datatype, - self.layout, - self.elementwise_function, - self.config_json, - ) - ) - - print( - f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." 
- ) - print(f" Tile configs: {len(tile_configs)}") - print(f" Trait combinations: {len(trait_combos)}") - print(f" Total kernels: {len(work_items)}") - - # Show first few work items for debugging - if work_items: - print(" First work item example:") - tile_config, trait_combo = work_items[0][:2] - print(f" Tile config: {tile_config}") - print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits - - # Process work items in parallel - kernel_list = [] - completed = 0 - - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_workers - ) as executor: - # Submit all work items - print(f" Submitting {len(work_items)} tasks to executor...") - future_to_item = { - executor.submit(_generate_single_kernel_individual, item): item - for item in work_items - } - print(" All tasks submitted, waiting for completion...") - - # Collect results with progress reporting - for future in concurrent.futures.as_completed(future_to_item): - completed += 1 - if completed % 100 == 0 or completed == len(work_items): - print( - f" Progress: {completed}/{len(work_items)} kernels generated" - ) - - try: - result = future.result() - if result: - kernel_list.append(result) - except Exception as exc: - item = future_to_item[future] - print(f"Kernel generation failed for {item}: {exc}") - - # Sort kernel list for consistent ordering - kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name - - # Generate CMake include file for individual targets - self._generate_cmake_individual_targets(kernel_list) - - print( - f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" - ) - - def _generate_cmake_individual_targets(self, kernel_list): - """Generate CMake include file that creates individual targets""" - cmake_code = f"""# Generated CMake file for individual GEMM Multi D targets - # Datatype: {self.datatype}, Layout: {self.layout} - """ - - for kernel_name, trait_combo, tile_config in kernel_list: - pipeline, epilogue, scheduler = trait_combo[:3] - - # Format 
tile config for CMake function - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( - str(x) for x in trait_combo[3:] - ) - - cmake_code += f'create_individual_gemm_multi_d_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' - - # Write CMake include file - with open( - self.working_path / "gemm_multi_d_individual_targets.cmake", "w" - ) as f: - f.write(cmake_code) - - -def _generate_single_kernel_individual(work_item): - """Worker function to generate a single individual kernel file""" - ( - tile_config, - trait_combo, - k_block_per_cu, - working_path, - gpu_target, - datatype, - layout, - elementwise_function, - config_json, - ) = work_item - - # Create a temporary builder instance for this worker - builder = GemmMultiDKernelBuilder( - working_path, - gpu_target, - datatype, - layout, - elementwise_function, - config_json, - ) - - try: - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu - ) - - # Create simplified filename without the "gemm_multi_d_" prefix - # Remove "gemm_multi_d_" from the beginning of kernel_name for the filename - simplified_name = kernel_name - if simplified_name.startswith("gemm_multi_d_"): - simplified_name = simplified_name[13:] # Remove "gemm_multi_d_" prefix - - # Write individual header file - header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp" - with open(header_file, "w") as f: - f.write(instance_code) - - return (kernel_name, trait_combo, tile_config) - except Exception as e: - print(f"Error generating individual kernel: {e}") - return None - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM Multi D kernel instance builder 
with parallel support" - ) - parser.add_argument("--working_path", required=True, help="Working directory path") - parser.add_argument("--gpu_target", required=True, help="GPU target architecture") - parser.add_argument( - "--datatype", - required=True, - choices=["fp16"], - help="Data type", - ) - parser.add_argument( - "--layout", - required=True, - choices=["rcrr", "rrrr", "ccrr", "crrr"], - help="Matrix layout", - ) - parser.add_argument( - "--elementwise_function", - required=True, - help="Specify what element wise function for D, e.g. mul, add, passthrough", - ) - parser.add_argument("--config_json", help="Configuration JSON file") - parser.add_argument( - "--num_workers", type=int, help="Number of parallel workers (default: auto)" - ) - parser.add_argument( - "--gen_all_individual", - action="store_true", - help="Generate individual kernel files", - ) - parser.add_argument( - "--gen_single", action="store_true", help="Generate a single kernel file" - ) - parser.add_argument("--kernel_name", help="Kernel name for single generation") - parser.add_argument( - "--tile_config", help="Tile configuration string for single generation" - ) - parser.add_argument( - "--trait_combo", help="Trait combination string for single generation" - ) - parser.add_argument( - "--list_kernels", - action="store_true", - help="List kernel configurations without generating files", - ) - - args = parser.parse_args() - - assert args.datatype in ["fp16"], ( - f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])" - ) - - layout_parts = args.layout.lower() - assert len(layout_parts) == 4, ( - f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)" - ) - assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], ( - f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)" 
- ) - assert layout_parts[2] == "r" and layout_parts[3] == "r", ( - f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} andf {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)" - ) - - # Elementwise function name validation - elementwise_function = args.elementwise_function.lower() - - valid_functions = ["mul", "add", "passthrough"] - if elementwise_function not in valid_functions: - raise ValueError( - f"Invalid elementwise function: {elementwise_function}. " - f"Valid options are: {', '.join(valid_functions)}" - ) - - # Set the function name based on the elementwise function - if elementwise_function == "mul": - function_name = "MultiDMultiply" - elif elementwise_function == "add": - function_name = "MultiDAdd" - elif elementwise_function == "passthrough": - function_name = "PassThrough" - - args.elementwise_function = function_name - - # Create builder - builder = GemmMultiDKernelBuilder( - args.working_path, - args.gpu_target, - args.datatype, - args.layout, - args.elementwise_function, - args.config_json, - ) - - if args.list_kernels: - builder.write_kernel_list() - elif args.gen_single: - # Generate a single kernel file - if not args.kernel_name or not args.tile_config or not args.trait_combo: - parser.error( - "--gen_single requires --kernel_name, --tile_config, and --trait_combo" - ) - - # Parse tile config - tile_parts = args.tile_config.split("_") - tile_dims = tile_parts[0].split("x") - warp_dims = tile_parts[1].split("x") - warp_tile_dims = tile_parts[2].split("x") - - tile_config = { - "tile_m": int(tile_dims[0]), - "tile_n": int(tile_dims[1]), - "tile_k": int(tile_dims[2]), - "warp_m": int(warp_dims[0]), - "warp_n": int(warp_dims[1]), - "warp_k": int(warp_dims[2]), - "warp_tile_m": int(warp_tile_dims[0]), - "warp_tile_n": int(warp_tile_dims[1]), - "warp_tile_k": int(warp_tile_dims[2]), - } - - # Parse trait combo - trait_parts = args.trait_combo.split("_") - trait_combo = ( - trait_parts[0], # pipeline - 
trait_parts[1], # epilogue - trait_parts[2], # scheduler - trait_parts[3] == "True", # pad_m - trait_parts[4] == "True", # pad_n - trait_parts[5] == "True", # pad_k - trait_parts[6] == "True", # persistent - ) - - k_block_per_cu = builder.config.get("k_block_per_cu") - if k_block_per_cu is None: - k_block_per_cu = 1 - - # Generate the kernel - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu - ) - - # Write the file - simplified_name = kernel_name - if simplified_name.startswith("gemm_multi_d_"): - simplified_name = simplified_name[13:] - - header_file = ( - builder.working_path / f"gemm_multi_d_single_{simplified_name}.hpp" - ) - with open(header_file, "w") as f: - f.write(instance_code) - - print(f"Generated {header_file}") - - elif args.gen_all_individual: - # Generate all individual kernel files - builder.run(args.num_workers) - else: - parser.error( - "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" - ) - - -if __name__ == "__main__": - main() diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp deleted file mode 100644 index 8c0c5f78d4..0000000000 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" - -// Helper function to determine if a layout is row-major -template -constexpr auto is_row_major(Layout) -{ - return ck_tile::bool_constant>{}; -} - -// Structure to hold kernel traits for dispatcher -struct KernelTraits -{ - std::string pipeline; // preshufflev2 - std::string scheduler; // intrawave, interwave, default - std::string epilogue; // cshuffle, default - bool pad_m; - bool pad_n; - bool pad_k; - bool persistent; - - // Constructor with defaults - KernelTraits() - : pipeline("preshufflev2"), - scheduler("default"), - epilogue("default"), - pad_m(false), - pad_n(false), - pad_k(false), - persistent(false) - { - } -}; - -// Helper to extract traits from kernel name -inline KernelTraits extract_traits_from_name(const std::string& kernel_name) -{ - KernelTraits traits; - - // Extract pipeline - if(kernel_name.find("preshufflev2") != std::string::npos) - { - traits.pipeline = "preshufflev2"; - } - - // Extract scheduler - if(kernel_name.find("interwave") != std::string::npos) - { - traits.scheduler = "interwave"; - } - else if(kernel_name.find("intrawave") != std::string::npos) - { - traits.scheduler = "intrawave"; - } - else - { - traits.scheduler = "default"; - } - - // Extract epilogue - if(kernel_name.find("default") != std::string::npos && - kernel_name.find("default_") == std::string::npos) - { - traits.epilogue = "default"; - } - else - { - traits.epilogue = "cshuffle"; - } - - // Padding flags would need to be extracted from the kernel configuration - // For now, we'll leave them as false - - return traits; -} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py deleted file mode 100644 index 62c239590a..0000000000 --- 
a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ /dev/null @@ -1,894 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -import argparse -import os -import json -import itertools -import logging -import multiprocessing -import concurrent.futures -from pathlib import Path -import importlib.util - - -def _import_validation_utils(): - """Import validation utilities from commons directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - parent_dir = os.path.dirname(current_dir) - - # Load the module dynamically - spec = importlib.util.spec_from_file_location( - "validation_utils", - os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), - ) - validation_utils = importlib.util.module_from_spec(spec) - spec.loader.exec_module(validation_utils) - - return validation_utils - - -# Import validation functions -_validation_utils = _import_validation_utils() -is_tile_config_valid = _validation_utils.is_tile_config_valid -is_trait_combination_valid = _validation_utils.is_trait_combination_valid -get_dtype_string = _validation_utils.get_dtype_string -get_abc_layouts = _validation_utils.get_abc_layouts - -logging.basicConfig(level=logging.INFO) - - -class GemmPreshuffleKernelBuilder: - def __init__(self, working_path, gpu_target, datatype, layout, config_json=None): - self.working_path = Path(working_path) - self.gpu_target = gpu_target - self.datatype = datatype - self.layout = layout - self.config_json = config_json - - # Create working directory if it doesn't exist - self.working_path.mkdir(parents=True, exist_ok=True) - - # Load configuration - if config_json and os.path.exists(config_json): - with open(config_json, "r") as f: - self.config = json.load(f) - - def write_kernel_list(self): - """Write kernel list to file for CMake to read (with comprehensive validation)""" - # Get configurations using comprehensive validation - tile_configs = 
self._get_tile_configs(fast_mode=False) - trait_combos = self._generate_trait_combinations() - - kernel_list = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - ( - pipeline, - epilogue, - scheduler, - pad_m, - pad_n, - pad_k, - persistent, - ) = trait_combo - - # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - - # Create tile configuration string - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - kernel_name += f"_{tile_str}" - - kernel_list.append( - { - "name": kernel_name, - "tile_config": tile_config, - "trait_combo": trait_combo, - } - ) - - # Write kernel count - with open(self.working_path / "gemm_preshuffle_kernel_count.txt", "w") as f: - f.write(str(len(kernel_list))) - - # Write kernel list - with open(self.working_path / "gemm_preshuffle_kernel_list.txt", "w") as f: - for kernel in kernel_list: - # Format: kernel_name|tile_config|trait_combo - tile_config = kernel["tile_config"] - trait_combo = kernel["trait_combo"] - - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - trait_str = ( - f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" - + "_".join(str(x) for x in trait_combo[3:]) - ) - - f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") - - print(f"Listed {len(kernel_list)} kernel configurations") - - def _get_tile_configs(self, fast_mode=False): 
- """Get tile configurations for the current datatype and layout""" - - tile_config = self.config["tile_config"] - - # Generate values in the config if default range is given - if tile_config.get("tile_m").get("values") is None: - tile_config.get("tile_m")["values"] = self._generate_values( - tile_config.get("tile_m").get("min"), - tile_config.get("tile_m").get("max"), - tile_config.get("tile_m").get("step"), - ) - if tile_config.get("tile_n").get("values") is None: - tile_config.get("tile_n")["values"] = self._generate_values( - tile_config.get("tile_n").get("min"), - tile_config.get("tile_n").get("max"), - tile_config.get("tile_n").get("step"), - ) - if tile_config.get("tile_k").get("values") is None: - tile_config.get("tile_k")["values"] = self._generate_values( - tile_config.get("tile_k").get("min"), - tile_config.get("tile_k").get("max"), - tile_config.get("tile_k").get("step"), - ) - - # Get all possible values for each parameter - tile_m_values = tile_config.get("tile_m").get("values") - tile_n_values = tile_config.get("tile_n").get("values") - tile_k_values = tile_config.get("tile_k").get("values") - warp_m_values = tile_config.get("warp_m").get("values") - warp_n_values = tile_config.get("warp_n").get("values") - warp_k_values = tile_config.get("warp_k").get("values") - warp_tile_m_values = tile_config.get("warp_tile_m").get("values") - warp_tile_n_values = tile_config.get("warp_tile_n").get("values") - warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - - # Generate all combinations - configs = [] - for tile_m in tile_m_values: - for tile_n in tile_n_values: - for tile_k in tile_k_values: - for warp_m in warp_m_values: - for warp_n in warp_n_values: - for warp_k in warp_k_values: - for warp_tile_m in warp_tile_m_values: - for warp_tile_n in warp_tile_n_values: - for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - 
warp_tile_n, - warp_tile_k, - fast_mode=fast_mode, - ): - configs.append( - { - "tile_m": tile_m, - "tile_n": tile_n, - "tile_k": tile_k, - "warp_m": warp_m, - "warp_n": warp_n, - "warp_k": warp_k, - "warp_tile_m": warp_tile_m, - "warp_tile_n": warp_tile_n, - "warp_tile_k": warp_tile_k, - } - ) - return configs - - def _generate_values(self, min_val, max_val, step): - """Generate a list of values from min to max with the given step""" - values = [] - val = min_val - while val <= max_val: - values.append(val) - val += step - return values - - def _generate_trait_combinations(self): - """Generate all combinations of traits""" - if "traits" in self.config: - # Old format - traits = self.config["traits"] - pipelines = traits["pipelines"] - epilogues = traits["epilogues"] - schedulers = traits["schedulers"] - - padding = self.config["padding"] - persistent = self.config["persistent"] - - all_combinations = list( - itertools.product( - pipelines, - epilogues, - schedulers, - padding["pad_m"], - padding["pad_n"], - padding["pad_k"], - persistent, - ) - ) - - # Filter out unsupported trait combinations - combinations = [] - for combo in all_combinations: - pipeline, epilogue, scheduler = combo[:3] - if is_trait_combination_valid(pipeline, epilogue, scheduler): - combinations.append(combo) - else: - logging.debug( - f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" - ) - - elif "trait_config" in self.config: - # New format - trait_config = self.config["trait_config"] - - pipelines = trait_config.get("pipeline", {}).get("values", ["preshufflev2"]) - epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) - schedulers = trait_config.get("scheduler", {}).get("values", ["default"]) - pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) - pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) - pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) - persistent_values = 
trait_config.get("persistent", {}).get( - "values", [False] - ) - - all_combinations = list( - itertools.product( - pipelines, - epilogues, - schedulers, - pad_m_values, - pad_n_values, - pad_k_values, - persistent_values, - ) - ) - - # Filter out unsupported trait combinations - combinations = [] - for combo in all_combinations: - pipeline, epilogue, scheduler = combo[:3] - if is_trait_combination_valid(pipeline, epilogue, scheduler): - combinations.append(combo) - else: - logging.debug( - f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" - ) - else: - # Fallback to minimal default - combinations = [ - ("preshufflev2", "default", "default", False, False, False, False) - ] - - return combinations - - def _validate_tile_config( - self, - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - pipeline="preshufflev2", # Default pipeline for validation - fast_mode=False, # Add fast mode option - ): - """Validate that tile configuration is reasonable""" - if fast_mode: - # Fast validation for listing - only basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: - return False - - # Basic divisibility check - if tile_m % (warp_m * warp_tile_m) != 0: - return False - if tile_n % (warp_n * warp_tile_n) != 0: - return False - if tile_k % (warp_k * warp_tile_k) != 0: - return False - - return True - else: - # Validate preshuffle specific constraints - if self.config.get("permute_n"): - valid = (tile_n / warp_tile_n / warp_n) % 2 == 0 - if not valid: - return False - - # Full validation for generation - # Determine data types for validation - a_datatype = self.datatype - b_datatype = self.datatype - c_datatype = self.datatype - - layout = self.layout - - # Special handling for certain data types - if self.datatype in ["fp8", "bf8"]: - c_datatype = 
"fp16" - - # Use the comprehensive validation function - return is_tile_config_valid( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - pipeline, - layout, - self.gpu_target, - ) - - def _generate_kernel_instance( - self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True - ): - """Generate a single kernel instance""" - ( - pipeline, - epilogue, - scheduler, - pad_m, - pad_n, - pad_k, - persistent, - ) = trait_combo - - # Create kernel name with proper boolean capitalization - kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" - - # Create tile configuration string - tile_str = ( - f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - ) - tile_str += ( - f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - ) - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - kernel_name += f"_{tile_str}" - - # Map pipeline names to the correct pipeline implementation - pipeline_impl_map = { - "preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2", - } - - # Map pipeline names to base pipeline for hot loop detection - base_pipeline_map = { - "preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2", - } - - # Map scheduler names to the correct enum values - scheduler_type_map = { - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "default": "ck_tile::GemmPipelineScheduler::Default", - } - - # Determine accumulator type based on datatype - acc_type = "float" - - # Determine output type - c_type = self.datatype - if self.datatype in ["fp8", "bf8"]: - c_type = "fp16" - - # Determine layouts based on self.layout - 
a_layout, b_layout, c_layout = get_abc_layouts(self.layout) - - # Generate kernel instance code using the correct API - pragma_line = "#pragma once\n" if is_header else "" - instance_code = f"""// Generated kernel instance for {kernel_name} -{pragma_line} -#include -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" - -using ADataType = {get_dtype_string(self.datatype)}; -using BDataType = {get_dtype_string(self.datatype)}; -using AccDataType = {acc_type}; -using CDataType = {get_dtype_string(c_type)}; - -using ALayout = {a_layout}; -using BLayout = {b_layout}; -using CLayout = {c_layout}; - -// Kernel name for display -constexpr const char* KERNEL_NAME = "{kernel_name}"; - -// Wrapper for simplified launch interface -struct SelectedKernel {{ - // Tile configuration - static constexpr ck_tile::index_t BlockSize = 256; - static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; - static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; - static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; - static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; - static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; - static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; - static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; - static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; - static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; - - // Traits - static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; - static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; - static constexpr bool kPadK = {"true" if 
pad_k == "true" else "false"}; - static constexpr bool TransposeC = false; - static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"}; - static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "preshufflev2" else "false"}; - static constexpr bool UseStructuredSparsity = false; - static constexpr bool Preshuffle = true; - static constexpr ck_tile::index_t NumWaveGroups = 1; - - static constexpr bool PermuteN = {"true" if permute_n else "false"}; - - // Tile shape - using TileShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence, - false, false>; - - // Tile partitioner - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - - // Traits - using Traits = ck_tile::TileGemmTraits; - - // Pipeline problem - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - Traits>; - - // Base pipeline for hot loop detection - using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}; - - static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")}; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - ck_tile::TileGemmUniversalTraits, - scheduler>; - - using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}; - - const auto Run = [&](const auto memory_operation_) {{ - [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - - // Epilogue -""" - - # Add epilogue configuration based on type - if epilogue == "cshuffle": - instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - 
ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - memory_operation, // MemoryOperation_ - NumWaveGroups, // kNumWaveGroups_ - false, // FixedVectorSize_ - 1, // VectorSizeC_ - PermuteN>; // isPermuteN_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue; -""" - else: # default epilogue - instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; -""" - - instance_code += f""" - - // Kernel type - using GemmKernel = ck_tile::GemmKernel; - - // Make kernel arguments - auto kargs = GemmKernel::MakeKernelArgs(args); - - if (!GemmKernel::IsSupportedArgument(kargs)) {{ - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!"); - }} - - // Get grid and block sizes - const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; - const dim3 blocks = GemmKernel::BlockSize(); - - if(stream.log_level_ > 0) {{ - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' - << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }} - - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - return ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - }}; - - if(args.k_batch == 1) {{ - return Run(ck_tile::integral_constant{{}}); - }} else {{ - return Run(ck_tile::integral_constant{{}}); - }} - }} -}}; -""" - - return kernel_name, instance_code - - def run(self, num_workers=None): - """Run the builder to generate individual kernel files""" - # Generate individual kernel files - self.generate_individual(num_workers) - - def generate_individual(self, num_workers=None): - """Generate individual kernel files for separate compilation with parallel processing""" - if num_workers is None: - num_workers = min( - multiprocessing.cpu_count(), 8 - ) # Limit to avoid memory issues - - tile_configs = self._get_tile_configs() - trait_combos = self._generate_trait_combinations() - k_block_per_cu = self.config.get("k_block_per_cu") - permute_n = self.config.get("permute_n") - - # Prepare work items for parallel processing - work_items = [] - for tile_config in tile_configs: - for trait_combo in trait_combos: - work_items.append( - ( - tile_config, - trait_combo, - k_block_per_cu, - permute_n, - self.working_path, - self.datatype, - self.layout, - ) - ) - - print( - f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." 
- ) - print(f" Tile configs: {len(tile_configs)}") - print(f" Trait combinations: {len(trait_combos)}") - print(f" Total kernels: {len(work_items)}") - - # Show first few work items for debugging - if work_items: - print(" First work item example:") - tile_config, trait_combo = work_items[0][:2] - print(f" Tile config: {tile_config}") - print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits - - # Process work items in parallel - kernel_list = [] - completed = 0 - - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_workers - ) as executor: - # Submit all work items - print(f" Submitting {len(work_items)} tasks to executor...") - future_to_item = { - executor.submit(_generate_single_kernel_individual, item): item - for item in work_items - } - print(" All tasks submitted, waiting for completion...") - - # Collect results with progress reporting - for future in concurrent.futures.as_completed(future_to_item): - completed += 1 - if completed % 100 == 0 or completed == len(work_items): - print( - f" Progress: {completed}/{len(work_items)} kernels generated" - ) - - try: - result = future.result() - if result: - kernel_list.append(result) - except Exception as exc: - item = future_to_item[future] - print(f"Kernel generation failed for {item}: {exc}") - - # Sort kernel list for consistent ordering - kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name - - # Generate CMake include file for individual targets - self._generate_cmake_individual_targets(kernel_list) - - print( - f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" - ) - - def _generate_cmake_individual_targets(self, kernel_list): - """Generate CMake include file that creates individual targets""" - cmake_code = f"""# Generated CMake file for individual GEMM Preshuffle targets -# Datatype: {self.datatype}, Layout: {self.layout} - -""" - - for kernel_name, trait_combo, tile_config in kernel_list: - pipeline, epilogue, scheduler = trait_combo[:3] - - # Format 
tile config for CMake function - tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" - tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" - tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" - - trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( - str(x) for x in trait_combo[3:] - ) - - cmake_code += f'create_individual_gemm_preshuffle_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' - - # Write CMake include file - with open( - self.working_path / "gemm_preshuffle_individual_targets.cmake", "w" - ) as f: - f.write(cmake_code) - - -def _generate_single_kernel_individual(work_item): - """Worker function to generate a single individual kernel file""" - ( - tile_config, - trait_combo, - k_block_per_cu, - permute_n, - working_path, - datatype, - layout, - ) = work_item - - # Create a temporary builder instance for this worker - builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout) - - try: - kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu, permute_n - ) - - # Create simplified filename without the "gemm_preshuffle_" prefix - # Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename - simplified_name = kernel_name - if simplified_name.startswith("gemm_preshuffle_"): - simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix - - # Write individual header file - header_file = working_path / f"gemm_single_{simplified_name}.hpp" - with open(header_file, "w") as f: - f.write(instance_code) - - return (kernel_name, trait_combo, tile_config) - except Exception as e: - print(f"Error generating individual kernel: {e}") - return None - - -def main(): - parser = argparse.ArgumentParser( - description="GEMM kernel instance builder with parallel support" - ) - parser.add_argument("--working_path", required=True, 
def _parse_tile_config(text):
    """Parse a "MxNxK_MxNxK_MxNxK" string (block tile, warps, warp tile).

    Returns a dict with tile_m/n/k, warp_m/n/k and warp_tile_m/n/k int keys.
    Raises ValueError/IndexError on malformed input, matching the previous
    inline parsing behavior.
    """
    parts = text.split("_")
    values = [
        int(v)
        for group in (parts[0], parts[1], parts[2])
        for v in group.split("x")
    ]
    keys = (
        "tile_m", "tile_n", "tile_k",
        "warp_m", "warp_n", "warp_k",
        "warp_tile_m", "warp_tile_n", "warp_tile_k",
    )
    return dict(zip(keys, values))


def _parse_trait_combo(text):
    """Parse "pipeline_epilogue_scheduler_padM_padN_padK_persistent".

    The last four fields are booleans encoded as the literal string "True".
    """
    p = text.split("_")
    return (
        p[0],            # pipeline
        p[1],            # epilogue
        p[2],            # scheduler
        p[3] == "True",  # pad_m
        p[4] == "True",  # pad_n
        p[5] == "True",  # pad_k
        p[6] == "True",  # persistent
    )


def main():
    """CLI entry point: parse arguments and dispatch to a generation mode.

    Exactly one of --list_kernels, --gen_all_individual or --gen_single
    must be given; --gen_single additionally requires --kernel_name,
    --tile_config and --trait_combo.
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", required=True, help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # Validate via parser.error (always active) instead of assert, which is
    # silently stripped under `python -O`. These checks are redundant with
    # the argparse `choices` lists above today, but guard against the choice
    # lists drifting from the supported sets.
    if args.datatype not in ("fp16", "bf16", "fp8", "bf8"):
        parser.error(
            f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
        )

    layout_parts = args.layout.lower()
    if len(layout_parts) != 3:
        parser.error(
            f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
        )
    if layout_parts[0] != "r" or layout_parts[1] != "c":
        parser.error(
            f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
        )
    if layout_parts[2] != "r":
        parser.error(
            f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
        )

    # Create builder shared by all generation modes.
    builder = GemmPreshuffleKernelBuilder(
        args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
    )

    if args.list_kernels:
        # Fast listing mode - just write kernel list without generating files.
        builder.write_kernel_list()
    elif args.gen_single:
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        tile_config = _parse_tile_config(args.tile_config)
        trait_combo = _parse_trait_combo(args.trait_combo)

        k_block_per_cu = builder.config.get("k_block_per_cu")
        permute_n = builder.config.get("permute_n")

        # Generate the kernel instance code.
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu, permute_n
        )

        # Strip the shared "gemm_preshuffle_" prefix from the filename;
        # length derived from the prefix string, not a magic index.
        prefix = "gemm_preshuffle_"
        simplified_name = kernel_name
        if simplified_name.startswith(prefix):
            simplified_name = simplified_name[len(prefix):]

        header_file = (
            builder.working_path / f"gemm_preshuffle_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)

        print(f"Generated {header_file}")
    elif args.gen_all_individual:
        # Generate all individual kernel files in parallel.
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )


if __name__ == "__main__":
    main()