From 20ef4380d74de7b7970feb5d547a9a643fd530ae Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Mon, 27 Oct 2025 09:15:34 -0500 Subject: [PATCH] Ck tile engine preshuffle (#2919) * Partial Progress : Preshuffle working code for datatype * Partial Progress : Preshuffle Cleanup * Working code for default config with min max step * Partial Progress : PermuteN implemented in validation * Partial Progress : PermuteN changes in Preshuffle * CK Tile Engine Preshuffle Complete * CK TILE ENGINE : Preshuffle Layout validation * CK Tile Engine Preshuffle Validation * Preshuffle Validation check * CK Tile Engine Preshuffle : Fixing Validation Cases * Addressing PR review Comments * Changes in config * Addressing Review Comments * Adding additional architecture in Jenkins * Partial Progress : Selective Datatype and layouts * Limited datatypes and layouts * Addressing CI errors * Datatype updates * Datatype updates * Datatype changes to Preshuffle * Addressing Review Comments * Addressing Review Comments * Datatype changes * Changes to Cmake * Update on Jenkins * Formatting with precommit * Ruff Formatting [ROCm/composable_kernel commit: 8b185e872e1172eba2444ccb49648469598d72e8] --- Jenkinsfile | 10 +- .../configs/simple_test_config.json | 6 +- tile_engine/ops/gemm/CMakeLists.txt | 2 +- .../ops/gemm_preshuffle/CMakeLists.txt | 13 +- .../commons/validation_utils.py | 148 ++++++++++++- .../configs/default_config.json | 81 ++++---- .../configs/user_provided_config.json | 22 +- ...ffle.hpp => gemm_preshuffle_benchmark.hpp} | 8 + ...p => gemm_preshuffle_benchmark_single.cpp} | 15 +- .../gemm_preshuffle_common.hpp | 76 +++---- .../gemm_preshuffle_instance_builder.py | 194 +++++++++++------- .../gemm_preshuffle_profiler.hpp | 27 ++- 12 files changed, 393 insertions(+), 209 deletions(-) rename tile_engine/ops/gemm_preshuffle/{benchmark_gemm_preshuffle.hpp => gemm_preshuffle_benchmark.hpp} (97%) rename tile_engine/ops/gemm_preshuffle/{benchmark_gemm_preshuffle_single.cpp => gemm_preshuffle_benchmark_single.cpp} (89%) diff --git a/Jenkinsfile b/Jenkinsfile index 7a8574df05..b89d6fb657 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1488,7 +1488,7 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ @@ -1528,7 +1528,7 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ @@ -1570,11 +1570,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ - --warmup 5 --repeat 5 --verbose --json results.json && \ - ninja -j64 benchmark_gemm_fp16_rcr && \ - ninja -j64 benchmark_gemm_fp16_rrr && \ - ninja -j64 benchmark_gemm_fp16_crr && \ - ninja -j64 benchmark_gemm_fp16_ccr """ + --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json index c80210b963..a4f32a1907 100644 --- a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json +++ b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json @@ -1,6 +1,4 @@ { - "problem": { - }, "tile_config": { "tile_m": { "values": [ @@ -85,5 +83,7 @@ false ] } - } + }, + "k_block_per_cu": 1, + "permute_n": false } diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 0e0ca41c9a..1eb49c0c7f 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,5 +1,5 @@ set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") -set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") +set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)") set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) diff --git a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt index 972ad9d0db..d80d2661d1 100644 --- a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt +++ b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt @@ -1,4 +1,4 @@ -set(GEMM_PRESHUFFLE_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") +set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF) @@ -65,7 +65,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con # Create the executable add_executable(${target_name} EXCLUDE_FROM_ALL - ${GEMM_PRESHUFFLE_SOURCE_DIR}/benchmark_gemm_preshuffle_single.cpp + ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp ${instance_header} ) @@ -176,7 +176,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout) OUTPUT_VARIABLE list_output ERROR_VARIABLE list_error ) - + if(NOT ret EQUAL 0) message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") endif() @@ -273,10 +273,10 @@ else() endforeach() # Create trait-based collection targets - # These are common trait components used across all GEMM kernels - set(GEMM_PRESHUFFLE_PIPELINES "preshufflev1;preshufflev2") + # These are common trait components used across all GEMM Preshuffle kernels + set(GEMM_PRESHUFFLE_PIPELINES "preshufflev2") set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle") - set(GEMM_PRESHUFFLE_SCHEDULERS "intrawave;interwave;default") + set(GEMM_PRESHUFFLE_SCHEDULERS "default") foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES) add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline) @@ -291,7 +291,6 @@ else() endforeach() # Build individual targets for each datatype/layout combination - foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE) foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT) build_individual_gemm_preshuffle_targets(${dt} ${l}) diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py index 454e26a7b5..b38ff5dffb 100644 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py @@ -32,7 +32,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -40,7 +39,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], @@ -52,7 +50,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -60,7 +57,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], @@ -73,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -81,7 +76,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [ @@ -122,6 +116,12 @@ def element_size(data_type: str) -> float: def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: """Check if a trait combination is valid.""" + if pipeline not in ["preshufflev2"]: + raise ValueError("Accepted pipeline values are: ['preshufflev2']") + if epilogue not in ["default", "cshuffle"]: + return ValueError("Accepted epilogue values are: ['default', 'cshuffle']") + if scheduler not in ["default"]: + return ValueError("Accepted scheduler values are: ['default']") return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS @@ -173,7 +173,7 @@ def validate_lds_capacity( matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( @@ -266,6 +266,35 @@ def is_tile_config_valid( if warp_k * warp_tile_k > tile_k: return False + # Validate vector load alignment + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + vector_valid, vector_error = validate_vector_load_alignment( + warp_tile_m, + warp_tile_k, + a_datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ) + if not vector_valid: + logging.debug(f"Vector load alignment failed: {vector_error}") + return False + + # Validate M0, M1, M2 configuration for matrix A row-major layout + m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + a_datatype, + vector_load_size=16, + warp_size=64, + ) + if not m0_m1_m2_valid: + logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") + return False + # Validate warp configuration if not validate_warp_configuration(warp_m, warp_n, warp_k): logging.debug( @@ -318,12 +347,117 @@ def is_tile_config_valid( return True +def validate_vector_load_alignment( + wg_m: int, + wg_k: int, + a_datatype: str, + m_iter_per_warp: int, + wave_size: int, + vector_load_size: int, +) -> Tuple[bool, str]: + try: + # Calculate the memory access pattern size + a_element_size = element_size(a_datatype) + access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size + + # Check if it's aligned to vector load size + if access_size % vector_load_size != 0: + error_msg = ( + f"Vector load alignment violation: " + f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " + f"% {vector_load_size} = {access_size % vector_load_size} != 0. " + f"Access size: {access_size} bytes" + ) + return False, error_msg + + return True, "" + + except Exception as e: + return False, f"Error in vector load validation: {str(e)}" + + +def validate_m0_m1_m2_configuration( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + a_datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> Tuple[bool, str]: + """ + Validate M0, M1, M2 configuration for matrix A row-major layout. + This ensures proper memory access pattern alignment. + """ + try: + # Validation for A as row-major + MPerBlock = tile_m + + # Calculate K1 using element size + K1 = vector_load_size / element_size(a_datatype) + + # Check if K1 is valid (must be integer) + if K1 != int(K1): + return ( + False, + f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", + ) + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False, f"tile_k({tile_k}) must be divisible by K1({K1})" + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False, f"warp_size({warp_size}) must be divisible by K0({K0})" + M2 = warp_size // K0 + + # Calculate number of warps and block size + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + # Calculate M0 (assuming get_warp_size() returns warp_size) + M0 = BlockSize // warp_size # This should equal NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False, f"M2({M2}) * M0({M0}) cannot be zero" + + if MPerBlock % (M2 * M0) != 0: + return ( + False, + f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", + ) + M1 = MPerBlock // (M2 * M0) + + # Validate the assertion: M0 * M1 * M2 == MPerBlock + calculated_m_per_block = M0 * M1 * M2 + if calculated_m_per_block != MPerBlock: + error_msg = ( + f"Incorrect M0, M1, M2 configuration! " + f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " + f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" + ) + return False, error_msg + + return True, "" + + except ZeroDivisionError as e: + return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" + except Exception as e: + return False, f"Error in M0/M1/M2 validation: {str(e)}" + + # [TODO] Handle this while moving code to commons Add more datatype to this function if needed def get_dtype_string(datatype: str) -> str: """Get C++ type string for datatype""" dtype_map = { "fp16": "ck_tile::fp16_t", "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", "bf16": "ck_tile::bf16_t", "fp32": "float", "fp64": "double", diff --git a/tile_engine/ops/gemm_preshuffle/configs/default_config.json b/tile_engine/ops/gemm_preshuffle/configs/default_config.json index d4c3537c65..4606cf0c27 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/default_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/default_config.json @@ -1,62 +1,72 @@ { "tile_config": { "tile_m": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "tile_n": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "tile_k": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "warp_m": { - "values": [ - 1 - ] + "values": [ + 4, + 2, + 1 + ] }, "warp_n": { - "values": [ - 4 - ] + "values": [ + 4, + 2, + 1 + ] }, "warp_k": { - "values": [ - 1 - ] + "values": [ + 1 + ] }, "warp_tile_m": { - "values": [ - 16 - ] + "values": [ + 4, + 16, + 32 + ] }, "warp_tile_n": { - "values": [ - 16 - ] + "values": [ + 16, + 32, + 64 + ] }, "warp_tile_k": { - "values": [ - 16,32 - ] + "values": [ + 8, + 16, + 32, + 64, + 128 + ] } }, "trait_config": { "pipeline": { "values": [ - "preshufflev1", "preshufflev2" ] }, "scheduler": { "values": [ - "interwave", - "intrawave" + "default" ] }, "epilogue": { @@ -81,11 +91,12 @@ ] }, "persistent": { - "values": [ - true, - false - ] + "values": [ + true, + false + ] } }, - "k_block_per_cu": 2 + "k_block_per_cu": 1, + "permute_n": true } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json index c0fc1f6cf8..cf7c79462e 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json @@ -2,27 +2,27 @@ "tile_config": { "tile_m": { "values": [ - 128 + 64 ] }, "tile_n": { "values": [ - 128 + 64 ] }, "tile_k": { "values": [ - 64 + 192 ] }, "warp_m": { "values": [ - 1 + 2 ] }, "warp_n": { "values": [ - 4 + 2 ] }, "warp_k": { @@ -42,7 +42,7 @@ }, "warp_tile_k": { "values": [ - 16,32 + 32 ] } }, @@ -54,12 +54,13 @@ }, "scheduler": { "values": [ - "intrawave" + "default" ] }, "epilogue": { "values": [ - "default" + "default", + "cshuffle" ] }, "pad_m": { @@ -79,9 +80,10 @@ }, "persistent": { "values": [ - false + true ] } }, - "k_block_per_cu": 8 + "k_block_per_cu": 1, + "permute_n": false } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp similarity index 97% rename from tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp rename to tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 74fccf6bf2..77a9f26527 100644 --- a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -23,6 +23,14 @@ inline constexpr auto get_metric_name(Metric m) } } +struct KernelConfig +{ + std::tuple tile_dims; + std::tuple warp_dims; + std::tuple warp_tile_dims; + bool permuteN; +}; + struct GemmProblem { int split_k_; diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp similarity index 89% rename from tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp rename to tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 152e27e77e..1f03d1cf9b 100644 --- a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) +void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType @@ -124,9 +124,16 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) try { // Create a lambda that wraps the kernel launch - std::tuple warp_tile_dims = std::make_tuple( SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK); + std::tuple tile_dims = + std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK); + std::tuple warp_dims = std::make_tuple(SelectedKernel::WarpPerBlock_M, + SelectedKernel::WarpPerBlock_N, + SelectedKernel::WarpPerBlock_K); + bool permuteN = SelectedKernel::PermuteN; + + KernelConfig config{tile_dims, warp_dims, warp_tile_dims, permuteN}; auto kernel_func = [](const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { @@ -134,7 +141,7 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) }; // Benchmark the kernel - profiler.benchmark(gemm_problem, kernel_func, warp_tile_dims); + profiler.benchmark(gemm_problem, kernel_func, config); // Select best instance based on metric profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); @@ -153,7 +160,7 @@ int main(int argc, char* argv[]) if(!result) return EXIT_FAILURE; - benchmark_gemm_preshuffle_single(parser); + benchmark_single(parser); return 0; } catch(const std::exception& e) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index 4fb98dc3c2..09ec895ab5 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -75,58 +75,6 @@ constexpr auto is_row_major(Layout) return ck_tile::bool_constant>{}; } -// // Permutation function for pk_int4_t -// template -// void permute_vectors_i4x4_b(Tensor& tensor) -// { -// const ck_tile::index_t K = tensor.get_length(0); -// const ck_tile::index_t N = tensor.get_length(1); -// // vector pk_i4x4 permute -// for(int i = 0; i < N; i++) -// { -// for(int j = 0; j < K; j += 8) -// { -// int8_t input[8]; - -// for(int k = 0; k < 4; k++) -// { -// int8_t i4x2 = tensor(j + k * 2, i).data; -// input[k * 2 + 0] = (i4x2 >> 4) & 0xf; -// input[k * 2 + 1] = (i4x2 >> 0) & 0xf; -// } - -// // permute 01234567->20643175 -// { -// int8_t hi = input[2]; -// int8_t lo = input[0]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 0, i) = i4x2; -// } - -// { -// int8_t hi = input[6]; -// int8_t lo = input[4]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 2, i) = i4x2; -// } - -// { -// int8_t hi = input[3]; -// int8_t lo = input[1]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 4, i) = i4x2; -// } - -// { -// int8_t hi = input[7]; -// int8_t lo = input[5]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 6, i) = i4x2; -// } -// } -// } -// } - // Structure to hold kernel traits for dispatcher struct KernelTraits { @@ -211,3 +159,27 @@ auto shuffle_b(const ck_tile::HostTensor& t, std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, + ck_tile::index_t N_Warp_Tile, + ck_tile::index_t K_Warp_Tile, + ck_tile::index_t N_Tile, + ck_tile::index_t N_Warp) +{ + assert(t.get_lengths().size() == 2); + + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int divisor = N_Warp_Tile == 32 ? 2 : 4; + int NRepeat = N_Tile / N_Warp_Tile / N_Warp; + ck_tile::HostTensor t_view({n_ / N_Tile, + N_Warp, + N_Warp_Tile, + NRepeat, + k_ / K_Warp_Tile, + divisor, + K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index e6e075cb36..1d4b027716 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -95,67 +95,87 @@ class GemmPreshuffleKernelBuilder: def _get_tile_configs(self, fast_mode=False): """Get tile configurations for the current datatype and layout""" - if "tile_configs" in self.config: - # Old format - return ( - self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + + tile_config = self.config["tile_config"] + + # Generate values in the config if default range is given + if tile_config.get("tile_m").get("values") is None: + tile_config.get("tile_m")["values"] = self._generate_values( + tile_config.get("tile_m").get("min"), + tile_config.get("tile_m").get("max"), + tile_config.get("tile_m").get("step"), + ) + if tile_config.get("tile_n").get("values") is None: + tile_config.get("tile_n")["values"] = self._generate_values( + tile_config.get("tile_n").get("min"), + tile_config.get("tile_n").get("max"), + tile_config.get("tile_n").get("step"), + ) + if tile_config.get("tile_k").get("values") is None: + tile_config.get("tile_k")["values"] = self._generate_values( + tile_config.get("tile_k").get("min"), + tile_config.get("tile_k").get("max"), + tile_config.get("tile_k").get("step"), ) - elif "tile_config" in self.config: - # New format - generate combinations from individual parameter values - tile_config = self.config["tile_config"] - # Get all possible values for each parameter - tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) - tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) - tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) - warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) - warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) - warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) - warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) - warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) - warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m").get("values") + tile_n_values = tile_config.get("tile_n").get("values") + tile_k_values = tile_config.get("tile_k").get("values") + warp_m_values = tile_config.get("warp_m").get("values") + warp_n_values = tile_config.get("warp_n").get("values") + warp_k_values = tile_config.get("warp_k").get("values") + warp_tile_m_values = tile_config.get("warp_tile_m").get("values") + warp_tile_n_values = tile_config.get("warp_tile_n").get("values") + warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - # Generate all combinations - configs = [] - for tile_m in tile_m_values: - for tile_n in tile_n_values: - for tile_k in tile_k_values: - for warp_m in warp_m_values: - for warp_n in warp_n_values: - for warp_k in warp_k_values: - for warp_tile_m in warp_tile_m_values: - for warp_tile_n in warp_tile_n_values: - for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - fast_mode=fast_mode, - ): - configs.append( - { - "tile_m": tile_m, - "tile_n": tile_n, - "tile_k": tile_k, - "warp_m": warp_m, - "warp_n": warp_n, - "warp_k": warp_k, - "warp_tile_m": warp_tile_m, - "warp_tile_n": warp_tile_n, - "warp_tile_k": warp_tile_k, - } - ) - return configs - else: - # Fallback to default - return [] + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + + def _generate_values(self, min_val, max_val, step): + """Generate a list of values from min to max with the given step""" + values = [] + val = min_val + while val <= max_val: + values.append(val) + val += step + return values def _generate_trait_combinations(self): """Generate all combinations of traits""" @@ -270,6 +290,12 @@ class GemmPreshuffleKernelBuilder: return True else: + # Validate preshuffle specific constraints + if self.config.get("permute_n"): + valid = (tile_n / warp_tile_n / warp_n) % 2 == 0 + if not valid: + return False + # Full validation for generation # Determine data types for validation a_datatype = self.datatype @@ -299,7 +325,7 @@ class GemmPreshuffleKernelBuilder: ) def _generate_kernel_instance( - self, tile_config, trait_combo, k_block_per_cu, is_header=True + self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True ): """Generate a single kernel instance""" ( @@ -349,9 +375,9 @@ class GemmPreshuffleKernelBuilder: acc_type = "float" # Determine output type - c_type = get_dtype_string(self.datatype) + c_type = self.datatype if self.datatype in ["fp8", "bf8"]: - c_type = "ck_tile::fp16_t" + c_type = "fp16" # Determine layouts based on self.layout a_layout, b_layout, c_layout = get_abc_layouts(self.layout) @@ -374,7 +400,7 @@ class GemmPreshuffleKernelBuilder: using ADataType = {get_dtype_string(self.datatype)}; using BDataType = {get_dtype_string(self.datatype)}; using AccDataType = {acc_type}; -using CDataType = {c_type}; +using CDataType = {get_dtype_string(c_type)}; using ALayout = {a_layout}; using BLayout = {b_layout}; @@ -408,6 +434,8 @@ struct SelectedKernel {{ static constexpr bool Preshuffle = true; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool PermuteN = {"true" if permute_n else "false"}; + // Tile shape using TileShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -485,7 +513,10 @@ struct SelectedKernel {{ WarpTileK, // KPerXdl_ TransposeC, // isCTransposed_ memory_operation, // MemoryOperation_ - NumWaveGroups>; // kNumWaveGroups_ + NumWaveGroups, // kNumWaveGroups_ + false, // FixedVectorSize_ + 1, // VectorSizeC_ + PermuteN>; // isPermuteN_ using GemmEpilogue = ck_tile::CShuffleEpilogue; """ @@ -580,6 +611,7 @@ struct SelectedKernel {{ tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() k_block_per_cu = self.config.get("k_block_per_cu") + permute_n = self.config.get("permute_n") # Prepare work items for parallel processing work_items = [] @@ -590,6 +622,7 @@ struct SelectedKernel {{ tile_config, trait_combo, k_block_per_cu, + permute_n, self.working_path, self.datatype, self.layout, @@ -681,21 +714,29 @@ struct SelectedKernel {{ def _generate_single_kernel_individual(work_item): """Worker function to generate a single individual kernel file""" - tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item + ( + tile_config, + trait_combo, + k_block_per_cu, + permute_n, + working_path, + datatype, + layout, + ) = work_item # Create a temporary builder instance for this worker builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout) try: kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu + tile_config, trait_combo, k_block_per_cu, permute_n ) - # Create simplified filename without the "gemm_" prefix - # Remove "gemm_" from the beginning of kernel_name for the filename + # Create simplified filename without the "gemm_preshuffle_" prefix + # Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename simplified_name = kernel_name - if simplified_name.startswith("gemm_"): - simplified_name = simplified_name[5:] # Remove "gemm_" prefix + if simplified_name.startswith("gemm_preshuffle_"): + simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix # Write individual header file header_file = working_path / f"gemm_single_{simplified_name}.hpp" @@ -727,7 +768,7 @@ def main(): parser.add_argument( "--layout", required=True, - choices=["rcr", "rrr", "ccr", "crr"], + choices=["rcr"], help="Matrix layout", ) parser.add_argument("--config_json", required=True, help="Configuration JSON file") @@ -735,7 +776,9 @@ def main(): "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - "--gen_individual", action="store_true", help="Generate individual kernel files" + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", ) parser.add_argument( "--gen_single", action="store_true", help="Generate a single kernel file" @@ -763,7 +806,7 @@ def main(): assert len(layout_parts) == 3, ( f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" ) - assert layout_parts[0] == "r" and layout_parts[1] == "c", ( + assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], ( f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)" ) assert layout_parts[2] == "r", ( @@ -816,10 +859,11 @@ def main(): ) k_block_per_cu = builder.config.get("k_block_per_cu") + permute_n = builder.config.get("permute_n") # Generate the kernel kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu + tile_config, trait_combo, k_block_per_cu, permute_n ) # Write the file @@ -835,13 +879,13 @@ def main(): print(f"Generated {header_file}") - elif args.gen_individual: + elif args.gen_all_individual: # Generate all individual kernel files builder.run(args.num_workers) pass else: parser.error( - "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" ) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 4f2a929ba0..7d212c934c 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -2,7 +2,7 @@ #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" -#include "benchmark_gemm_preshuffle.hpp" +#include "gemm_preshuffle_benchmark.hpp" class GemmProfiler { @@ -17,7 +17,7 @@ class GemmProfiler void benchmark(GemmProblem& gemm_problem, std::function kernel_func, - const std::tuple& warp_tile_dims) + KernelConfig& config) { // Create a vector with a single callable that returns both name and time std::vector(ck_tile::GemmHostArgs&, @@ -30,13 +30,13 @@ class GemmProfiler return std::make_tuple(std::string(KERNEL_NAME), time); }); - benchmark(gemm_problem, callables, warp_tile_dims); + benchmark(gemm_problem, callables, config); } void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables, - const std::tuple& warp_tile_dims) + KernelConfig& config) { const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; @@ -110,11 +110,22 @@ class GemmProfiler for(const auto& callable : callables) { - ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims); - ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims); + ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); + ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims); + ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); + ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); + + ck_tile::HostTensor b_shuffle_host = [&]() { + if(config.permuteN) + { + return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); + } + else + { + return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + } + }(); - ck_tile::HostTensor b_shuffle_host = - shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); ck_tile::GemmHostArgs gemm_args = {