From d405641f06162f2a6b1bf15f890caa7105beebe4 Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Mon, 3 Nov 2025 10:29:16 +0100 Subject: [PATCH 1/7] Ck tile engine gemm unit tests exapand test coverage (#3025) * initial commit for testing datatypes, layouts and traits * correct warp tile size for small datatype config to make a validate instance for fp16, bf16, fp8 * add tile size coverage test * Cover more tests, parallel instance generation, documentation * update cmakelist to run more tests * initial codes to support add test params in json file * add congurable problem sizes for different tests * modify README.md * clean test_gemm_simple code * correct padding coverage test * Add comprehensive and quick tile size config files * remove fp64 from datatypes * update documents. manage selecting tile_size config (quick or Comprehensive) * correct padding test problem sizes * update comprehensive test and correct documents * Skip GEMM tests with unsupported arguments instead of failing * change gen_single instead of gen_indivisual because of an issue. add splitk tests to tile_size_quick_config * clean CMakeList, remod py file * Refactor test configs: Rename tile_size to coverage, remove separate traits config, clean cmakefile, readme * update fp32, fp8 to test all layouts, clean documents and comments * limit fp32 test layouts to rcr because of compilation error on some gpus * remove fp32 because of the removing from gemm_instance_builder, make quick test smaller, updating comments * Fix fp8/bf8 test failures on gfx950 by adding OCP FP8 format support * Reduce quick_coverage test count from ~250 to ~144 for faster CI --- .../include/ck_tile/builder/CMakeLists.txt | 2 +- test/ck_tile/gemm_tile_engine/CMakeLists.txt | 272 ++++++++++++------ test/ck_tile/gemm_tile_engine/README.md | 58 ++++ .../comprehensive_coverage_config.json | 37 +++ .../configs/large_datatype_config.json | 34 +++ .../configs/padding_coverage_config.json | 34 +++ .../configs/quick_coverage_config.json | 34 +++ .../configs/simple_test_config.json | 107 ++----- .../configs/small_datatype_config.json | 35 +++ .../gemm_tile_engine/extract_test_params.py | 71 +++++ .../gemm_tile_engine/test_gemm_simple.cpp | 34 ++- 11 files changed, 545 insertions(+), 173 deletions(-) create mode 100644 test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json create mode 100644 test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json create mode 100644 test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json create mode 100644 test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json create mode 100644 test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json create mode 100644 test/ck_tile/gemm_tile_engine/extract_test_params.py diff --git a/experimental/builder/include/ck_tile/builder/CMakeLists.txt b/experimental/builder/include/ck_tile/builder/CMakeLists.txt index f20b5d54ec..45723c3680 100644 --- a/experimental/builder/include/ck_tile/builder/CMakeLists.txt +++ b/experimental/builder/include/ck_tile/builder/CMakeLists.txt @@ -1 +1 @@ -# Empty placeholder until we add library code. +#Empty placeholder until we add library code. diff --git a/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_tile_engine/CMakeLists.txt index 0174028c99..8ad0f2af75 100644 --- a/test/ck_tile/gemm_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_tile_engine/CMakeLists.txt @@ -32,43 +32,35 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") - # Generated header path for this specific kernel configuration + # Generated header path (already created during cmake configuration) set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + set(test_params_header "${working_path}/test_params.hpp") - # Generate kernel header using tile_engine's Python script - add_custom_command( - OUTPUT ${test_header} - COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py - --working_path ${working_path} - --gpu_target "${GEMM_TEST_GPU_TARGETS}" - --datatype ${datatype} - --layout ${layout} - --config_json ${config_json} - --gen_single - --kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}" - --tile_config "${tile_config}" - --trait_combo "${trait}" - DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json} - COMMENT "Generating test header ${test_header}" - VERBATIM - ) + # Verify header exists (should have been generated during cmake configuration) + if(NOT EXISTS ${test_header}) + message(WARNING "Generated header not found: ${test_header}") + return() + endif() + + # Verify test parameters header exists + if(NOT EXISTS ${test_params_header}) + message(WARNING "Test parameters header not found: ${test_params_header}") + return() + endif() + # Create GTest executable for this kernel configuration add_gtest_executable(${target_name} ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp ) - # Ensure header is generated before compilation - set(header_target "${target_name}_header") - add_custom_target(${header_target} DEPENDS ${test_header}) - add_dependencies(${target_name} ${header_target}) - # Configure GPU architectures for HIP compilation set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) - # Define preprocessor macros for generated header location + # Define preprocessor macros for generated header location and test parameters target_compile_definitions(${target_name} PRIVATE GEMM_SINGLE_INSTANCE_HPP="${test_header}" + GEMM_TEST_PARAMS_HPP="${test_params_header}" ) # Include directories for headers and dependencies @@ -87,6 +79,11 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti -include ${test_header} # Auto-include generated header ) + # Add FP8 format definitions for proper data type interpretation + if(CK_USE_OCP_FP8) + target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() + message(STATUS " Created test target: ${target_name}") endfunction() @@ -107,7 +104,6 @@ function(build_gemm_test_targets datatype layout config_name) # Locate and validate configuration file set(config_filename "${config_name}.json") set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") - message(STATUS " Using test config: ${config_filename}") if(NOT EXISTS ${json_blob}) message(WARNING "Test config file not found: ${json_blob}") @@ -118,7 +114,6 @@ function(build_gemm_test_targets datatype layout config_name) file(MAKE_DIRECTORY ${working_path}) # STEP 1: Discovery phase - list all valid kernel configurations - message(STATUS " Listing kernel configurations for ${datatype}_${layout}...") execute_process( COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py --working_path ${working_path} @@ -134,32 +129,90 @@ function(build_gemm_test_targets datatype layout config_name) ) if(NOT ret EQUAL 0) - message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}") + message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}") return() endif() - # Validate kernel discovery results - if(EXISTS ${working_path}/gemm_kernel_count.txt) - file(READ ${working_path}/gemm_kernel_count.txt kernel_count) - string(STRIP "${kernel_count}" kernel_count) - message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}") - else() - message(WARNING "Kernel count file not found for ${datatype}_${layout}") + # Verify kernel list file was generated + if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) + message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") return() endif() - # STEP 2: Generation phase - create test targets for each discovered kernel - if(EXISTS ${working_path}/gemm_kernel_list.txt) - file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) - set(test_count 0) - foreach(line IN LISTS kernel_lines) - # Parse kernel specification format: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(LENGTH parts parts_len) - if(parts_len EQUAL 3) - list(GET parts 0 kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) + message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + + # STEP 2a: Extract test parameters from config + set(test_params_file "${working_path}/test_params.hpp") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py + --config_file ${json_blob} + --output_file ${test_params_file} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE extract_ret + OUTPUT_VARIABLE extract_output + ERROR_VARIABLE extract_error + ) + + if(NOT extract_ret EQUAL 0) + message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}") + return() + endif() + + # STEP 2b: Header generation phase - generate headers using --gen_single + message(STATUS " Generating headers using --gen_single...") + + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(gen_count 0) + + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate header using --gen_single + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --gpu_target "${GEMM_TEST_GPU_TARGETS}" + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gen_single + --kernel_name "${kernel_name}" + --tile_config "${tile_config}" + --trait_combo "${trait_combo}" + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE gen_ret + OUTPUT_VARIABLE gen_output + ERROR_VARIABLE gen_error + ) + + if(NOT gen_ret EQUAL 0) + message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}") + else() + math(EXPR gen_count "${gen_count} + 1") + endif() + endif() + endforeach() + + message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}") + + # STEP 3: Target creation phase - create test targets + message(STATUS " Creating test targets...") + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(test_count 0) + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) # Generate test target for this kernel configuration create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") @@ -167,12 +220,7 @@ function(build_gemm_test_targets datatype layout config_name) endif() endforeach() message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") - else() - message(WARNING "Kernel list file not found for ${datatype}_${layout}") - endif() -endfunction() - -# ============================================================================ +endfunction()# ============================================================================ # MAIN EXECUTION - Test Target Generation # ============================================================================ @@ -198,42 +246,100 @@ endif() message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") -# ============================================================================ -# Test Configuration Matrix -# ============================================================================ + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) -# Available test configurations (minimal set for fast CI/testing) -set(TEST_CONFIGS - "simple_test_config" - # "medium_tiles_config" # Uncomment for broader testing -) - -# Data types for testing (core precision types) -set(TEST_DATATYPES "fp16" "bf16") -# Extended data type options: -# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8") - -# Matrix layouts for testing (row-column-row is most common) -set(TEST_LAYOUTS "rcr") -# Extended layout options: -# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr") + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF) + if(ENABLE_CCACHE_TESTS) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster test compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)") + endif() # ============================================================================ -# Test Target Generation Loop +# Test Configuration Matrix - Clean Focused Design # ============================================================================ -foreach(datatype IN LISTS TEST_DATATYPES) - foreach(layout IN LISTS TEST_LAYOUTS) - foreach(config IN LISTS TEST_CONFIGS) - set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json") - if(EXISTS ${CONFIG_FILE}) - message(STATUS "Building tests for ${datatype}_${layout}_${config}") - build_gemm_test_targets("${datatype}" "${layout}" "${config}") - else() - message(WARNING "Config file not found: ${CONFIG_FILE}") - endif() +# All supported data types and layouts for comprehensive testing +# Note: fp64 not included (no MFMA hardware support) +set(TEST_DATATYPES "fp16;fp8;bf16;fp32") +set(TEST_LAYOUTS "rcr;rrr;ccr;crr") + +# ============================================================================ +# Test Target Generation - Datatype-Specific Categories +# ============================================================================ + +# 1. SMALL DATATYPES: Test optimized config for small data types (fp8, fp16, bf16) +# These data types can use larger warp tiles due to smaller memory footprint +set(SMALL_DATATYPE_CONFIG "small_datatype_config") +set(SMALL_DATATYPE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SMALL_DATATYPE_CONFIG}.json") +set(SMALL_DATATYPES "fp8;fp16;bf16") + +if(EXISTS ${SMALL_DATATYPE_CONFIG_FILE}) + message(STATUS "Processing small datatype config: ${SMALL_DATATYPE_CONFIG} (fp8, fp16, bf16)") + foreach(datatype IN LISTS SMALL_DATATYPES) + # fp8, fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) + foreach(layout IN LISTS TEST_LAYOUTS) + build_gemm_test_targets("${datatype}" "${layout}" "${SMALL_DATATYPE_CONFIG}") endforeach() endforeach() -endforeach() +else() + message(WARNING "Small datatype config file not found: ${SMALL_DATATYPE_CONFIG_FILE}") +endif() -message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations") +# 2. PADDING COVERAGE: Test padding combinations with fixed fp16/rcr configuration +# This focuses on padding behavior (pad_m, pad_n, pad_k) +set(PADDING_CONFIG "padding_coverage_config") +set(PADDING_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${PADDING_CONFIG}.json") + +if(EXISTS ${PADDING_CONFIG_FILE}) + message(STATUS "Processing padding config: ${PADDING_CONFIG} (fp16/rcr only)") + build_gemm_test_targets("fp16" "rcr" "${PADDING_CONFIG}") +else() + message(WARNING "Padding config file not found: ${PADDING_CONFIG_FILE}") +endif() + +# 3. COVERAGE LEVEL: Quick or comprehensive testing +# Quick: ~144 kernels with multiple tile sizes and trait combinations +# Comprehensive: Several thousand kernels with extensive tile sizes, warp configurations, and all trait combinations +set(COVERAGE_LEVEL "quick" CACHE STRING "Coverage level: quick or comprehensive") +set_property(CACHE COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive") + +if(COVERAGE_LEVEL STREQUAL "quick") + set(COVERAGE_CONFIG "quick_coverage_config") + set(COVERAGE_DESC "Quick - approximately 144 kernels with trait combinations") +elseif(COVERAGE_LEVEL STREQUAL "comprehensive") + set(COVERAGE_CONFIG "comprehensive_coverage_config") + set(COVERAGE_DESC "Comprehensive - several thousand kernels with extensive tile and trait coverage") +else() + message(FATAL_ERROR "Invalid COVERAGE_LEVEL: ${COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'") +endif() + +set(COVERAGE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COVERAGE_CONFIG}.json") + +if(EXISTS ${COVERAGE_CONFIG_FILE}) + message(STATUS "Processing coverage config: ${COVERAGE_LEVEL} - ${COVERAGE_DESC}") + build_gemm_test_targets("fp16" "rcr" "${COVERAGE_CONFIG}") +else() + message(WARNING "Coverage config file not found: ${COVERAGE_CONFIG_FILE}") +endif() +# ============================================================================ + + +message(STATUS "GEMM tile engine tests configured with datatype-specific design:") +message(STATUS " - Small datatypes: fp8/fp16/bf16 (all layouts)") +message(STATUS " - Padding coverage with fp16/rcr") +message(STATUS " - Coverage level: ${COVERAGE_LEVEL} (~144 kernels quick, several thousand comprehensive)") +message(STATUS " Use -DCOVERAGE_LEVEL=comprehensive for extensive testing") diff --git a/test/ck_tile/gemm_tile_engine/README.md b/test/ck_tile/gemm_tile_engine/README.md index d99b4115d3..87ce0c9fd0 100644 --- a/test/ck_tile/gemm_tile_engine/README.md +++ b/test/ck_tile/gemm_tile_engine/README.md @@ -17,11 +17,69 @@ JSON Config → tile_engine Python scripts → Generated Headers → Test Execut ``` - **`--list_kernels`**: Get available kernel configurations from JSON +- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration - **`--gen_single`**: Generate individual kernel header for each configuration - **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations - **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching +### Config-Specific Test Parameters +Each test configuration can specify optimized problem sizes in its JSON file: +- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations +- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files +- **Build integration**: Each test target uses parameters appropriate for its kernel configuration +- **Optimized testing**: Different configs test different problem sizes that showcase their strengths The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. + +## Test Configurations + +### 1. **Simple Test** (`simple_test_config.json`) +- **Purpose**: Basic functionality validation +- **Config**: 128x128x64, warp 2x2x1, warp_tile 16x16x16 +- **Traits**: compv3 + compv4 pipelines +- **Coverage**: ~2 kernels per datatype/layout + +### 2. **Small Datatype** (`small_datatype_config.json`) +- **Purpose**: Optimized for fp8/fp16/bf16 data types +- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +- **Traits**: compv3 pipeline only +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp8, fp16, bf16 + +### 3. **Padding Coverage** (`padding_coverage_config.json`) +- **Purpose**: Test padding behavior with all padding flags enabled +- **Config**: Fixed 64x64x32, warp 2x2x1, warp_tile 32x32x16 +- **Padding**: All enabled (pad_m=true, pad_n=true, pad_k=true) +- **Problem sizes**: Vector-aligned but not tile-aligned (104×104×56, 200×152×80, 152×200×64) +- **Coverage**: 1 kernel configuration testing padding with irregular sizes + +### 4. **Coverage Testing** (Quick or Comprehensive) +- **Purpose**: Comprehensive testing across tile sizes, warp configurations, and trait combinations +- **Quick** (`quick_coverage_config.json`): Approximately 144 kernels + - tile_m/n: [32, 64, 256], tile_k: [16, 32] + - warp config: 2×2×1, warp_tile 16×16×16 + - Traits: 3 pipelines × 2 epilogues × 2 schedulers (persistent=false only) + - Focused set testing trait combinations with multiple tile sizes +- **Comprehensive** (`comprehensive_coverage_config.json`): Several thousand kernels + - tile_m/n: [16-256 step 16] + - tile_k: [16, 32, 64] + - warp_m/n: [1, 2, 4], warp_tile_m/n: [16, 32], warp_tile_k: [16, 32] + - Traits: 3 pipelines × 2 epilogues × 2 schedulers × 2 persistent + - Extensive coverage across all tile sizes, warp configurations, and trait combinations + - Exact count varies based on validation filtering +- **Note**: Use CMake option `-DCOVERAGE_LEVEL=comprehensive` to enable comprehensive testing (default is quick) + +## Data Type Support +- ✅ **fp8, fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ❌ **fp64**: Not supported (hardware MFMA limitation) +- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) + +## Test Result Behavior + +Tests automatically handle unsupported configurations through runtime validation: +- **PASSED**: Kernel executed correctly with results within error thresholds ✅ +- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️ +- **FAILED**: Actual error or incorrect computation results ❌ + +When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements. diff --git a/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json new file mode 100644 index 0000000000..f2524e4a61 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json @@ -0,0 +1,37 @@ +{ + "problem": { + "description": "Comprehensive coverage testing - extensive tile size coverage (16-256, step 16) with multiple warp configurations and all trait combinations. Several thousand kernels." + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_n": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_k": {"values": [16, 32, 64]}, + "warp_m": {"values": [1, 2, 4]}, + "warp_n": {"values": [1, 2, 4]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16, 32]}, + "warp_tile_n": {"values": [16, 32]}, + "warp_tile_k": {"values": [8, 16, 32, 64, 128]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [true, false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json b/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json new file mode 100644 index 0000000000..e9fcb6fb80 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Configuration optimized for large data types (fp32) with smaller warp tiles due to memory constraints" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 192, "split_k": 1}, + {"m": 256, "n": 384, "k": 192, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [256]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json new file mode 100644 index 0000000000..33bada839d --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Padding coverage testing - fixed config with fp16/rcr, varying only padding combinations" + }, + "test_params": { + "problem_sizes": [ + {"m": 104, "n": 104, "k": 56, "split_k": 1}, + {"m": 200, "n": 152, "k": 80, "split_k": 1}, + {"m": 152, "n": 200, "k": 64, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [64]}, + "tile_n": {"values": [64]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [true]}, + "pad_n": {"values": [true]}, + "pad_k": {"values": [true]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json new file mode 100644 index 0000000000..dcc6e99aee --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Quick coverage testing - tests multiple tile sizes with all trait combinations (pipelines, epilogues, schedulers). Approximately 144 kernels." + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [32, 64, 256]}, + "tile_n": {"values": [32, 64, 256]}, + "tile_k": {"values": [16, 32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json index a4f32a1907..498ef9fa33 100644 --- a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json +++ b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json @@ -1,88 +1,33 @@ { + "problem": { + "description": "Basic functionality validation with moderate problem sizes" + }, + "test_params": { + "problem_sizes": [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 256, "split_k": 1}, + {"m": 256, "n": 512, "k": 256, "split_k": 1} + ] + }, "tile_config": { - "tile_m": { - "values": [ - 128 - ] - }, - "tile_n": { - "values": [ - 128 - ] - }, - "tile_k": { - "values": [ - 64 - ] - }, - "warp_m": { - "values": [ - 2 - ] - }, - "warp_n": { - "values": [ - 2 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] - }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [64]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} }, "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4" - ] - }, - "scheduler": { - "values": [ - "intrawave" - ] - }, - "epilogue": { - "values": [ - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false - ] - } + "pipeline": {"values": ["compv3", "compv4"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} }, "k_block_per_cu": 1, "permute_n": false diff --git a/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json b/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json new file mode 100644 index 0000000000..d0d9f99a0c --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json @@ -0,0 +1,35 @@ +{ + "problem": { + "description": "Configuration optimized for small data types (fp8, fp16, bf16) with larger warp tiles" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/extract_test_params.py b/test/ck_tile/gemm_tile_engine/extract_test_params.py new file mode 100644 index 0000000000..c82591e391 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/extract_test_params.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import json +import argparse +import os +from pathlib import Path + + +def extract_test_params(config_file, output_file): + """Extract test parameters from config JSON and write to output file""" + + # Read config file + with open(config_file, "r") as f: + config = json.load(f) + + # Extract test parameters + test_params = [] + if "test_params" in config and "problem_sizes" in config["test_params"]: + test_params = config["test_params"]["problem_sizes"] + else: + # Default test parameters if none specified + test_params = [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 256, "n": 256, "k": 1024, "split_k": 1}, + {"m": 256, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 256, "k": 512, "split_k": 1}, + ] + + # Write to output file in C++ format + output_dir = Path(output_file).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_file, "w") as f: + f.write("// Generated test parameters for this configuration\n") + f.write("// This file is auto-generated during CMake configuration\n\n") + f.write("static const std::vector CONFIG_TEST_PARAMS = {\n") + + for i, params in enumerate(test_params): + comma = "," if i < len(test_params) - 1 else "" + f.write( + f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n" + ) + + f.write("};\n") + + print( + f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Extract test parameters from config JSON" + ) + parser.add_argument("--config_file", required=True, help="Input config JSON file") + parser.add_argument( + "--output_file", required=True, help="Output test parameters file" + ) + + args = parser.parse_args() + + if not os.path.exists(args.config_file): + print(f"Error: Config file not found: {args.config_file}") + return 1 + + extract_test_params(args.config_file, args.output_file) + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp index 439dd4f39b..2054136647 100644 --- a/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp +++ b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp @@ -1,8 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -// Unit tests for tile_engine generated GEMM kernels -// Tests kernel correctness using tile_engine's verification methodology +/** + * @file test_gemm_simple.cpp + * @brief Unit tests for GEMM kernels generated by gemm_instance_builder + * + * This test includes kernels generated during CMake configuration by + * gemm_instance_builder.py and tests them with problem sizes extracted + * from the corresponding JSON configuration files. + */ #include #include @@ -68,6 +74,11 @@ struct GemmTestParams int m, n, k, split_k; }; +// Include config-specific test parameters (after GemmTestParams struct is defined) +#ifdef GEMM_TEST_PARAMS_HPP +#include GEMM_TEST_PARAMS_HPP +#endif + class GemmTileEngineTest : public ::testing::TestWithParam { protected: @@ -185,7 +196,16 @@ TEST_P(GemmTileEngineTest, BasicFunctionality) } catch(const std::exception& e) { - FAIL() << "Kernel launch failed: " << e.what(); + std::string error_msg(e.what()); + // If arguments not supported, skip the test (configuration validation failure, not a bug) + if(error_msg.find("Arguments not supported") != std::string::npos) + { + GTEST_SKIP() << "Configuration not supported: " << e.what(); + } + else + { + FAIL() << "Kernel launch failed: " << e.what(); + } } // Copy result back from device @@ -208,13 +228,11 @@ TEST_P(GemmTileEngineTest, KernelInfo) << std::endl; } -// Define test parameters for GEMM verification +// Use config-specific test parameters (included via compile flags) +// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file INSTANTIATE_TEST_SUITE_P(GemmVerification, GemmTileEngineTest, - ::testing::Values(GemmTestParams{256, 256, 128, 1}, - GemmTestParams{256, 256, 1024, 1}, - GemmTestParams{256, 512, 512, 1}, - GemmTestParams{512, 256, 512, 1}), + ::testing::ValuesIn(CONFIG_TEST_PARAMS), [](const ::testing::TestParamInfo& param_info) { return std::to_string(param_info.param.m) + "x" + std::to_string(param_info.param.n) + "x" + From afe1ff618df6fb28532331560f9b40a0b396a1da Mon Sep 17 00:00:00 2001 From: Michael Mcminn <47832147+UD-mmcminn@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:31:31 -0500 Subject: [PATCH 2/7] Ud fix moe sorting gfx908 (#2720) * Adding a ds permute fallback for the gfx908 and older for row_newbcast:7 instruction * Better macro for selecting ROW_NEWBCAST * clang-format the update --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../fused_moe/kernel/moe_sorting_kernel.hpp | 70 ++++++++++++------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 2918cd33bc..f6189c7495 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -10,6 +10,26 @@ #include #include +#if !defined(CK_TILE_HAS_ROW_NEWBCAST) +// row_newbcast (DPP modifier 0x157) support by architecture: +// - Not supported: gfx908 (MI100) and older +// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures + +#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__) +#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__) +// Explicitly disable for known unsupported architectures +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#else +// Assume support for gfx90a and newer (including all gfx94x and RDNA) +// This is safer as new architectures typically maintain backward compatibility +#define CK_TILE_HAS_ROW_NEWBCAST 1 +#endif +#else +// Conservative default for non-AMD or host compilation +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#endif +#endif + namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ @@ -380,18 +400,7 @@ struct MoeSortingKernel row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx =__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), 0x157, @@ -401,6 +410,17 @@ struct MoeSortingKernel data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0 + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } #endif } @@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx = __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), @@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = + __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if((__lane_id() / 8) % 2 != 0) + { // Note: != 0, not == 0 + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } #endif } if constexpr(wave_size > 8) From 2ec57a8e704f55b545877f6e4f545ebda4a21833 Mon Sep 17 00:00:00 2001 From: Emily Martins Date: Mon, 27 Oct 2025 17:26:04 +0000 Subject: [PATCH 3/7] Replace CK_TILE_PIPELINE macros with a common enum This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3, CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum called GemmPipeline to reduce code duplication. --- example/ck_tile/03_gemm/gemm_basic.cpp | 2 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 2 +- example/ck_tile/03_gemm/gemm_utils.hpp | 89 +++++++++---------- example/ck_tile/03_gemm/universal_gemm.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 36 ++++---- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 67 +++++++------- .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 46 +++++----- .../19_gemm_multi_d/gemm_multi_d_fp16.hpp | 37 ++++---- .../20_grouped_convolution/conv_configs.hpp | 57 ++++++------ .../22_gemm_multi_abd/gemm_multi_abd_fp16.hpp | 41 ++++----- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/pipeline/gemm_pipelines.hpp | 21 +++++ .../gemm/test_gemm_pipeline_smoke_util.hpp | 53 +++++------ 13 files changed, 220 insertions(+), 238 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index f92f6ef87a..3c26661c84 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, ck_tile::half_t, diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index dbed40800e..6d833fbd7a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,13 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_COMPUTE_V6 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -69,7 +62,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaveGroups = 2; }; @@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -465,11 +458,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -478,7 +471,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -487,7 +480,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -496,7 +489,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; @@ -505,7 +498,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; @@ -514,7 +507,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index f9a7263a5f..a8a7288a3d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "fp16i4") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "fp8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "bf8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 33da0bf0a5..c0935a0e46 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - struct GemmConfigMemory { // Memory friendly for Interwave scheduler @@ -30,9 +26,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -50,9 +46,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -71,9 +67,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -109,7 +105,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -118,7 +114,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 57d3f224d8..049957cbfd 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,11 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -87,7 +82,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool Persistent = true; @@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 1; }; @@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; }; template @@ -201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr bool kPadK = true; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr bool kPadK = true; }; template @@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -269,7 +264,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -278,7 +273,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -287,7 +282,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 12d70eecb6..81c0b654e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -44,8 +40,8 @@ struct GemmConfigBase static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet static constexpr bool Persistent = false; // currently persistent == true is not supported yet static constexpr bool DoubleSmemBuffer = @@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool Persistent = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 : public GemmConfigBase @@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 : public GemmConfigBase { @@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma : public GemmConfigBase @@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -148,7 +144,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -157,7 +153,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index a7ae227627..8a621cd4be 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -7,12 +7,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using D0DataType = ck_tile::half_t; @@ -36,9 +33,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -56,9 +53,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -77,9 +74,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -97,16 +94,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -115,7 +112,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -124,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 1be6080383..c688215280 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -12,11 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - struct ConvConfigBase { static constexpr bool kPadM = true; @@ -37,7 +32,7 @@ struct ConvConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -61,9 +56,9 @@ struct ConvConfigMemoryInterwave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -81,8 +76,8 @@ struct ConvConfigMemoryIntrawave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -101,8 +96,8 @@ struct ConvConfigComputeV3 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -120,8 +115,8 @@ struct ConvConfigComputeV3_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -139,8 +134,8 @@ struct ConvConfigComputeV3_2 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -160,8 +155,8 @@ struct ConvConfigComputeV3_WMMA : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -183,8 +178,8 @@ struct ConvConfigComputeV4 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -202,8 +197,8 @@ struct ConvConfigComputeV4_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -222,7 +217,7 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -245,8 +240,8 @@ struct ConvConfigComputeV3_merged_groups : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumGroupsToMerge = 2; }; @@ -294,11 +289,11 @@ struct DataTypeTraits static constexpr const char* name = "bf16"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -307,7 +302,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -316,7 +311,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -325,7 +320,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp index 35bc232eca..76a2635e5f 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp @@ -7,16 +7,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif - using A0DataType = ck_tile::half_t; using A1DataType = ck_tile::half_t; @@ -49,9 +42,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -69,9 +62,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -90,9 +83,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -110,16 +103,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -128,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -137,7 +130,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 33be18948b..ec2d2488c8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -55,6 +55,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp new file mode 100644 index 0000000000..9b948626f6 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +enum struct GemmPipeline +{ + COMPUTE_ASYNC, + COMPUTE_V3, + COMPUTE_V4, + COMPUTE_V5, + COMPUTE_V6, + MEMORY, + BASIC_V1, + BASIC_V2, + PRESHUFFLE_V2 +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 0820be5b30..1f9033cab9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -10,11 +10,6 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - class ArgumentsNotSupportedException : public std::logic_error { public: @@ -56,7 +51,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -76,9 +71,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -96,8 +91,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -116,8 +111,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -135,8 +130,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -154,8 +149,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -177,8 +172,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -196,8 +191,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -216,7 +211,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -235,8 +230,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -401,11 +396,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -414,7 +409,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -423,7 +418,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -432,7 +427,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; From 057b7d43b4f1edd4bc6e881403588af8c8e96fd4 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 3 Nov 2025 09:37:35 -0800 Subject: [PATCH 4/7] fix the compv4 and async pipeline when tile handler is 1 (#3141) --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 17 ++++++++++++++++- .../pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 1d2a3e180b..91da3cd27b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { + if(num_loop == 1) + { + return TailNumber::One; + } if(num_loop % PrefetchStages == 1) { return TailNumber::Three; @@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync return run_func(bool_constant{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } - else + else if(TailNum == TailNumber::Two) { // 2 { @@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } + else if(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + __builtin_amdgcn_sched_barrier(0); + } return c_block_tile; } }; From 507d81c3af51b81f15b946a2a4bef7f594620292 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:59:01 +0100 Subject: [PATCH 5/7] Fix splitk preshuffle (#3137) * Fix splitK multiply_multiply_wp * Add tests for gemm_multiply_multiply_wp * Add tests for gemm_universal_preshuffle (KBatch = 1) * Add tests gemm_blockscale_wp * Fix splitk gemm universal preshuffle * Run new tests on arch supporting fp8 * Restore example * Fix strides profiler * Fix tests * Fix clang format * Finalize profiler preshuffle with tolerances * Minor improvements to splitk related changes * Address review comments: clang format and ckProfiler typo * Remove b_k_split_offset from SplitKBatchOffset struct --- ...vice_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 5 + ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 77 +++++++------ ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 67 +++++++----- profiler/include/profiler/common.hpp | 103 ++++++++++++++++++ .../profile_gemm_blockscale_wp_impl.hpp | 46 +++++--- ...profile_gemm_multiply_multiply_wp_impl.hpp | 29 ++++- ...profile_gemm_universal_preshuffle_impl.hpp | 25 ++++- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 83 +------------- profiler/src/profile_gemm_blockscale_wp.cpp | 26 ++--- test/CMakeLists.txt | 3 + test/gemm_blockscale_wp/CMakeLists.txt | 6 + .../test_gemm_blockscale_wp_xdl_fp8.cpp | 64 +++++++++++ test/gemm_blockscale_wp/test_gemm_common.hpp | 77 +++++++++++++ test/gemm_multiply_multiply_wp/CMakeLists.txt | 6 + .../test_gemm_common.hpp | 93 ++++++++++++++++ ...test_gemm_multiply_multiply_wp_xdl_fp8.cpp | 77 +++++++++++++ test/gemm_universal_preshuffle/CMakeLists.txt | 6 + .../test_gemm_common.hpp | 79 ++++++++++++++ ...test_gemm_universal_preshuffle_xdl_fp8.cpp | 77 +++++++++++++ 19 files changed, 777 insertions(+), 172 deletions(-) create mode 100644 profiler/include/profiler/common.hpp create mode 100644 test/gemm_blockscale_wp/CMakeLists.txt create mode 100644 test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp create mode 100644 test/gemm_blockscale_wp/test_gemm_common.hpp create mode 100644 test/gemm_multiply_multiply_wp/CMakeLists.txt create mode 100644 test/gemm_multiply_multiply_wp/test_gemm_common.hpp create mode 100644 test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp create mode 100644 test/gemm_universal_preshuffle/CMakeLists.txt create mode 100644 test/gemm_universal_preshuffle/test_gemm_common.hpp create mode 100644 test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index ebd168a7d0..ea4e6de6fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 0) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 78546c4f99..6ce2f63e3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + k_id, + Kt); } #else ignore = karg; @@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared_0, p_shared_1, - karg); + karg, + k_id, + Kt); } #else ignore = karg; @@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - if constexpr(!PermuteB) - { - // b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; - - b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize; - } - else - { - const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; - } - } - if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; index_t c_reduce_offset; }; @@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const BDataType* p_b_grid, CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + const index_t k_id, + const index_t Kt) { index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } template ( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle CDataType* p_c_grid, void* p_shared_0, void* p_shared_1, - const Problem& problem) + const Problem& problem, + const index_t k_id, + const index_t Kt) { index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 2e95ec0d52..f2f1530599 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id, + Kt); } #else ignore = karg; @@ -79,11 +87,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, @@ -91,7 +105,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id, + Kt); } #else ignore = karg; @@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = k_id * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - // KPack * NLane * KLane * K0 * N0 - b_k_split_offset = k_id * karg.KRead * NLane; - } - if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; @@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const index_t k_id, + const index_t Kt) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run( @@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id, + Kt); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const index_t k_id, + const index_t Kt) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run_2Lds( @@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id, + Kt); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment diff --git a/profiler/include/profiler/common.hpp b/profiler/include/profiler/common.hpp new file mode 100644 index 0000000000..2f72e67c6b --- /dev/null +++ b/profiler/include/profiler/common.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace profiler { + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 0921b48842..da0dc60760 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -69,19 +69,19 @@ template -bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideE, - int n_warmup, - int n_iter, - uint64_t rotating = 0) +bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + int n_warmup, + int n_iter, + uint64_t rotating = 0) { bool pass = true; @@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, ALayout{}, StrideA); + StrideB = get_stride(b0_k_n, BLayout{}, StrideB); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() + a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes(); diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp index c76387e2b0..21613e49c6 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0); + StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes(); @@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); d1_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); break; default: @@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, is_same_v)) { std::string msg = "Error: Incorrect results!"; - double rtol = 1e-3; - double atol = 5e-2; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); } diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index e537cf2770..5ec056efd1 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC); + std::size_t total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); int rotating_count = std::max( @@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, is_same_v) { std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index b553e07735..ae12070014 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -5,92 +5,11 @@ #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { -template -inline constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template ? N : K; const int DefaultStrideE = ck::is_same_v ? N : M; - bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl( + bool pass = ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl( do_verification, init_method, do_log, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 810ae8d231..d47e55db64 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -245,10 +245,13 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) +add_subdirectory(gemm_blockscale_wp) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) +add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) +add_subdirectory(gemm_universal_preshuffle) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt new file mode 100644 index 0000000000..d198db0870 --- /dev/null +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) + endif() +endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..5d88e04690 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F32, F8, F32, F8, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_blockscale_wp/test_gemm_common.hpp b/test/gemm_blockscale_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..25ed67a737 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_common.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_blockscale_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmBlockscaleWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using A0DataType = std::tuple_element_t<2, Tuple>; + using A1DataType = std::tuple_element_t<3, Tuple>; + using B0DataType = std::tuple_element_t<4, Tuple>; + using B1DataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using CDataType = std::tuple_element_t<7, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + static constexpr index_t ScaleBlockM = 1; + static constexpr index_t ScaleBlockN = 128; + static constexpr index_t ScaleBlockK = 128; + + void Run(const int M, const int N, const int K, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multiply_multiply_wp/CMakeLists.txt b/test/gemm_multiply_multiply_wp/CMakeLists.txt new file mode 100644 index 0000000000..4302084a6f --- /dev/null +++ b/test/gemm_multiply_multiply_wp/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance) + endif() +endif() diff --git a/test/gemm_multiply_multiply_wp/test_gemm_common.hpp b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..37e2b353e6 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmMultiplyMultiplyWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = Row; + using ADataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using EDataType = std::tuple_element_t<9, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 4}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v, Row> ? K : M; + int StrideB = std::is_same_v, Row> ? N : K; + int StrideD0 = std::is_same_v, Row> ? N : M; + int StrideD1 = std::is_same_v, Row> ? N : M; + int StrideE = std::is_same_v ? N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl( + verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..bf9b909628 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmMultiplyMultiplyWP_FP8_MK_NK + : public ck::test::TestGemmMultiplyMultiplyWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F8, F8, F32, F32, F16>, + std::tuple< F8, F8, F8, F32, F32, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..0d8955f6a4 --- /dev/null +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) + endif() +endif() diff --git a/test/gemm_universal_preshuffle/test_gemm_common.hpp b/test/gemm_universal_preshuffle/test_gemm_common.hpp new file mode 100644 index 0000000000..367c1a9c7e --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_common.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_universal_preshuffle_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmUniversalPreshuffleCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using CDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 4}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? N : M; + + all_success = all_success & + ck::profiler::profile_gemm_universal_preshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp new file mode 100644 index 0000000000..06dca026ee --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversalPreshuffle_FP8_MK_NK + : public ck::test::TestGemmUniversalPreshuffleCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F8, F8, F16>, + std::tuple< F8, F8, F8, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} From c7ded76cc784f0b4d2c24d3985cb587ad22cbd7f Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Mon, 3 Nov 2025 12:21:57 -0800 Subject: [PATCH 6/7] Adding note on CMake convenience script (#3139) * Adding note on convenience script * Addressing feedback * Update README.md reword --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> --- README.md | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 32688b6574..01d523c2ab 100644 --- a/README.md +++ b/README.md @@ -93,13 +93,44 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa want to build the library for a list of different architectures, you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`. -4. Build the entire CK library: + **Convenience script for development builds:** + + Alternatively, you can use the provided convenience script `script/cmake-ck-dev.sh` which automatically + configures CK for development with sensible defaults. In the build directory: + + ```bash + ../script/cmake-ck-dev.sh + ``` + + This script: + * Cleans CMake cache files before configuring + * Sets `BUILD_DEV=ON` for development mode + * Defaults to GPU targets: `gfx908;gfx90a;gfx942` + * Enables verbose makefile output + * Sets additional compiler flags for better error messages + + By default, it considers the parent directory to be the project source directory. + + You can specify the source directory as the first argument. + You can specify custom GPU targets (semicolon-separated) as the second argument: + + ```bash + ../script/cmake-ck-dev.sh .. gfx1100 + ``` + + Or pass additional cmake arguments: + + ```bash + ../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release + ``` + +5. Build the entire CK library: ```bash make -j"$(nproc)" ``` -5. Install CK: +6. Install CK: ```bash make -j install From 99f38e4d9bedcf1b09d58653c354f042f8c509ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 4 Nov 2025 00:34:48 +0100 Subject: [PATCH 7/7] [CK TILE] Refactor grouped conv fwd large tensor (#3144) --- .../grouped_convolution_forward_invoker.hpp | 150 ++++++++-------- ...nvolution_forward_large_tensor_invoker.hpp | 167 +++++++++--------- .../grouped_convolution_utils.hpp | 2 + .../grouped_convolution_forward_kernel.hpp | 5 +- .../utils/grouped_convolution_utils.hpp | 4 +- 5 files changed, 161 insertions(+), 167 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 89922fc07b..d9a65b9639 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -112,89 +112,87 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; + return ave_time; + }; // ===================================================================== // Split-K lambda @@ -202,11 +200,11 @@ struct GroupedConvolutionForwardInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 6a76057d73..2e98c0863b 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -53,7 +53,10 @@ struct GroupedConvolutionForwardInvoker OutLayout, VectorSizeA, VectorSizeB, - VectorSizeC>; + VectorSizeC, + 1, /*NumGroupsToMerge*/ + ck_tile::element_wise::PassThrough, + true /*EnableSplitImage*/>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< GemmConfig::kPadM, @@ -238,68 +241,64 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - // Use split-image kernel if layout supports it, otherwise use regular kernel - using Kernel = ck_tile::GroupedConvolutionForwardKernel; + // Use split-image kernel if layout supports it, otherwise use regular kernel + using Kernel = ck_tile::GroupedConvolutionForwardKernel; - // Create kargs - auto kargs = Kernel::MakeKernelArgs(args); + // Create kargs + auto kargs = Kernel::MakeKernelArgs(args); - // Populate split-image metadata ONLY if using split-image kernel - if constexpr(EnableSplitImage) - { + // Populate split-image metadata ONLY if using split-image kernel kargs.num_spatial_pieces = total_pieces; kargs.split_image.total_d = total_d; kargs.split_image.total_h = total_h; @@ -320,41 +319,35 @@ struct GroupedConvolutionForwardInvoker temp_pieces[i].h_size, temp_pieces[i].w_size}; } - } - // Calculate grid: use total_blocks for split-image, or normal GridSize for regular - const dim3 grids = [&]() { - if constexpr(EnableSplitImage) - return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits); - else - return Kernel::GridSize(kargs); - }(); - const dim3 blocks = Kernel::BlockSize(); + // Calculate grid: use total_blocks for split-image, or normal GridSize for regular + const dim3 grids = dim3(total_blocks, kargs.GemmBatch, kargs.n_splits); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; + return ave_time; + }; // ===================================================================== // Step 4: Dispatch kernel (split-image or regular based on decision) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 91fa444f0d..b0e2c02973 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -11,7 +11,9 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" + using MemoryOpSet = std::integral_constant; using MemoryOpAtomicAdd = std::integral_constant struct GroupedConvolutionForwardKernel { - static constexpr bool EnableSplitImage = EnableSplitImage_; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = GroupedConvTraitsType_::ConvSpecialization; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 8695fecac6..703205fd6e 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -63,7 +63,8 @@ template + typename CDElementwise_ = PassThrough, + bool EnableSplitImage_ = false> struct GroupedConvTraits { private: @@ -74,6 +75,7 @@ struct GroupedConvTraits } public: + static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;