From d1c71e62836bce00fc44c74f7450c7350003fff0 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 27 Oct 2025 15:13:31 +0000 Subject: [PATCH] Merge commit '06973b1cf4987b5f2e7fc1fe504b56df58edaf1f' into develop --- Jenkinsfile | 50 ++++- .../profiler/profile_gemm_multi_abd_impl.hpp | 80 +++----- .../configs/simple_test_config.json | 6 +- tile_engine/ops/gemm/CMakeLists.txt | 2 +- .../ops/gemm_preshuffle/CMakeLists.txt | 13 +- .../commons/validation_utils.py | 148 ++++++++++++- .../configs/default_config.json | 81 ++++---- .../configs/user_provided_config.json | 22 +- ...ffle.hpp => gemm_preshuffle_benchmark.hpp} | 8 + ...p => gemm_preshuffle_benchmark_single.cpp} | 15 +- .../gemm_preshuffle_common.hpp | 76 +++---- .../gemm_preshuffle_instance_builder.py | 194 +++++++++++------- .../gemm_preshuffle_profiler.hpp | 27 ++- 13 files changed, 461 insertions(+), 261 deletions(-) rename tile_engine/ops/gemm_preshuffle/{benchmark_gemm_preshuffle.hpp => gemm_preshuffle_benchmark.hpp} (97%) rename tile_engine/ops/gemm_preshuffle/{benchmark_gemm_preshuffle_single.cpp => gemm_preshuffle_benchmark_single.cpp} (89%) diff --git a/Jenkinsfile b/Jenkinsfile index 7a8574df05..9acbbeeca2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -12,6 +12,14 @@ def show_node_info() { """ } +// Error patterns to scan build logs for specific failure types and send detailed notifications. +def failurePatterns = [ + [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], + [pattern: /docker login failed/, description: "Docker login failed"], + [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], + [pattern: /cat: .* No such file or directory/, description: "GPU not found"], +] + class Version { int major, minor, patch @Override @@ -1488,7 +1496,7 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ @@ -1528,7 +1536,7 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ @@ -1570,11 +1578,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ - --warmup 5 --repeat 5 --verbose --json results.json && \ - ninja -j64 benchmark_gemm_fp16_rcr && \ - ninja -j64 benchmark_gemm_fp16_rrr && \ - ninja -j64 benchmark_gemm_fp16_crr && \ - ninja -j64 benchmark_gemm_fp16_ccr """ + --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1853,4 +1857,36 @@ pipeline { } } } + post { + failure { + node(rocmnode("nogpu")) { + script { + // Get the build log. + def buildLog = sh(script: 'wget -q --no-check-certificate -O - ' + BUILD_URL + 'consoleText', returnStdout: true) + // Check for patterns in the log. + def foundPatterns = [] + for (patternMap in failurePatterns) { + def result = checkForPattern(patternMap.pattern, buildLog) + if (result.found) { + foundPatterns.add([ + description: patternMap.description, + matchedLine: result.matchedLine, + context: result.context + ]) + } + } + // Send a notification for each matched failure pattern. + for (patternMap in foundPatterns) { + withCredentials([string(credentialsId: 'ck_ci_errors_webhook_url', variable: 'WEBHOOK_URL')]) { + sh ''' + curl -X POST "${WEBHOOK_URL}" \ + -H 'Content-Type: application/json' \ + -d '{"text": "\\n\\n**Build Failed**\\n\\n**Issues detected:** ''' + patternMap.description + '''\\n\\n**Log context:**\\n```\\n''' + patternMap.context.replace("'", "\\'") + '''\\n```\\n\\n**Job:** ''' + env.JOB_NAME + '''\\n\\n**Build:** #''' + env.BUILD_NUMBER + '''\\n\\n**URL:** ''' + env.RUN_DISPLAY_URL + '''"}' + ''' + } + } + } + } + } + } } diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp index 46745fd02b..51922fde33 100644 --- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -188,66 +188,42 @@ bool profile_gemm_multi_abd_impl(int do_verification, EDataType, remove_cvref_t>>::type; - auto get_a_matrix = [&]() -> auto { - // in case of pass through we avoid allocating a new - // tensor and copying values - if constexpr(is_same_v) + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) { - return as_m_k(Number<0>{}); + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); } - else - { - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(a_element_op, data_refs); - } - } - return a_m_k; - } - }; + } using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - auto get_b_matrix = [&]() -> auto { - // in case of pass through we avoid allocating a new - // tensor and copying values - if constexpr(is_same_v) + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) { - return bs_k_n(Number<0>{}); + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); } - else - { - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(b_element_op, data_refs); - } - } - return b_k_n; - } - }; + } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm float: def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: """Check if a trait combination is valid.""" + if pipeline not in ["preshufflev2"]: + raise ValueError("Accepted pipeline values are: ['preshufflev2']") + if epilogue not in ["default", "cshuffle"]: + return ValueError("Accepted epilogue values are: ['default', 'cshuffle']") + if scheduler not in ["default"]: + return ValueError("Accepted scheduler values are: ['default']") return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS @@ -173,7 +173,7 @@ def validate_lds_capacity( matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( @@ -266,6 +266,35 @@ def is_tile_config_valid( if warp_k * warp_tile_k > tile_k: return False + # Validate vector load alignment + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + vector_valid, vector_error = validate_vector_load_alignment( + warp_tile_m, + warp_tile_k, + a_datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ) + if not vector_valid: + logging.debug(f"Vector load alignment failed: {vector_error}") + return False + + # Validate M0, M1, M2 configuration for matrix A row-major layout + m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + a_datatype, + vector_load_size=16, + warp_size=64, + ) + if not m0_m1_m2_valid: + logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") + return False + # Validate warp configuration if not validate_warp_configuration(warp_m, warp_n, warp_k): logging.debug( @@ -318,12 +347,117 @@ def is_tile_config_valid( return True +def validate_vector_load_alignment( + wg_m: int, + wg_k: int, + a_datatype: str, + m_iter_per_warp: int, + wave_size: int, + vector_load_size: int, +) -> Tuple[bool, str]: + try: + # Calculate the memory access pattern size + a_element_size = element_size(a_datatype) + access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size + + # Check if it's aligned to vector load size + if access_size % vector_load_size != 0: + error_msg = ( + f"Vector load alignment violation: " + f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " + f"% {vector_load_size} = {access_size % vector_load_size} != 0. " + f"Access size: {access_size} bytes" + ) + return False, error_msg + + return True, "" + + except Exception as e: + return False, f"Error in vector load validation: {str(e)}" + + +def validate_m0_m1_m2_configuration( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + a_datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> Tuple[bool, str]: + """ + Validate M0, M1, M2 configuration for matrix A row-major layout. + This ensures proper memory access pattern alignment. + """ + try: + # Validation for A as row-major + MPerBlock = tile_m + + # Calculate K1 using element size + K1 = vector_load_size / element_size(a_datatype) + + # Check if K1 is valid (must be integer) + if K1 != int(K1): + return ( + False, + f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", + ) + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False, f"tile_k({tile_k}) must be divisible by K1({K1})" + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False, f"warp_size({warp_size}) must be divisible by K0({K0})" + M2 = warp_size // K0 + + # Calculate number of warps and block size + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + # Calculate M0 (assuming get_warp_size() returns warp_size) + M0 = BlockSize // warp_size # This should equal NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False, f"M2({M2}) * M0({M0}) cannot be zero" + + if MPerBlock % (M2 * M0) != 0: + return ( + False, + f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", + ) + M1 = MPerBlock // (M2 * M0) + + # Validate the assertion: M0 * M1 * M2 == MPerBlock + calculated_m_per_block = M0 * M1 * M2 + if calculated_m_per_block != MPerBlock: + error_msg = ( + f"Incorrect M0, M1, M2 configuration! " + f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " + f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" + ) + return False, error_msg + + return True, "" + + except ZeroDivisionError as e: + return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" + except Exception as e: + return False, f"Error in M0/M1/M2 validation: {str(e)}" + + # [TODO] Handle this while moving code to commons Add more datatype to this function if needed def get_dtype_string(datatype: str) -> str: """Get C++ type string for datatype""" dtype_map = { "fp16": "ck_tile::fp16_t", "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", "bf16": "ck_tile::bf16_t", "fp32": "float", "fp64": "double", diff --git a/tile_engine/ops/gemm_preshuffle/configs/default_config.json b/tile_engine/ops/gemm_preshuffle/configs/default_config.json index d4c3537c65..4606cf0c27 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/default_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/default_config.json @@ -1,62 +1,72 @@ { "tile_config": { "tile_m": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "tile_n": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "tile_k": { - "values": [ - 128 - ] + "max": 256, + "min": 64, + "step": 64 }, "warp_m": { - "values": [ - 1 - ] + "values": [ + 4, + 2, + 1 + ] }, "warp_n": { - "values": [ - 4 - ] + "values": [ + 4, + 2, + 1 + ] }, "warp_k": { - "values": [ - 1 - ] + "values": [ + 1 + ] }, "warp_tile_m": { - "values": [ - 16 - ] + "values": [ + 4, + 16, + 32 + ] }, "warp_tile_n": { - "values": [ - 16 - ] + "values": [ + 16, + 32, + 64 + ] }, "warp_tile_k": { - "values": [ - 16,32 - ] + "values": [ + 8, + 16, + 32, + 64, + 128 + ] } }, "trait_config": { "pipeline": { "values": [ - "preshufflev1", "preshufflev2" ] }, "scheduler": { "values": [ - "interwave", - "intrawave" + "default" ] }, "epilogue": { @@ -81,11 +91,12 @@ ] }, "persistent": { - "values": [ - true, - false - ] + "values": [ + true, + false + ] } }, - "k_block_per_cu": 2 + "k_block_per_cu": 1, + "permute_n": true } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json index c0fc1f6cf8..cf7c79462e 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json @@ -2,27 +2,27 @@ "tile_config": { "tile_m": { "values": [ - 128 + 64 ] }, "tile_n": { "values": [ - 128 + 64 ] }, "tile_k": { "values": [ - 64 + 192 ] }, "warp_m": { "values": [ - 1 + 2 ] }, "warp_n": { "values": [ - 4 + 2 ] }, "warp_k": { @@ -42,7 +42,7 @@ }, "warp_tile_k": { "values": [ - 16,32 + 32 ] } }, @@ -54,12 +54,13 @@ }, "scheduler": { "values": [ - "intrawave" + "default" ] }, "epilogue": { "values": [ - "default" + "default", + "cshuffle" ] }, "pad_m": { @@ -79,9 +80,10 @@ }, "persistent": { "values": [ - false + true ] } }, - "k_block_per_cu": 8 + "k_block_per_cu": 1, + "permute_n": false } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp similarity index 97% rename from tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp rename to tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 74fccf6bf2..77a9f26527 100644 --- a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -23,6 +23,14 @@ inline constexpr auto get_metric_name(Metric m) } } +struct KernelConfig +{ + std::tuple tile_dims; + std::tuple warp_dims; + std::tuple warp_tile_dims; + bool permuteN; +}; + struct GemmProblem { int split_k_; diff --git a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp similarity index 89% rename from tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp rename to tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 152e27e77e..1f03d1cf9b 100644 --- a/tile_engine/ops/gemm_preshuffle/benchmark_gemm_preshuffle_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) +void benchmark_single(const ck_tile::ArgParser& arg_parser) { // Use DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType @@ -124,9 +124,16 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) try { // Create a lambda that wraps the kernel launch - std::tuple warp_tile_dims = std::make_tuple( SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK); + std::tuple tile_dims = + std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK); + std::tuple warp_dims = std::make_tuple(SelectedKernel::WarpPerBlock_M, + SelectedKernel::WarpPerBlock_N, + SelectedKernel::WarpPerBlock_K); + bool permuteN = SelectedKernel::PermuteN; + + KernelConfig config{tile_dims, warp_dims, warp_tile_dims, permuteN}; auto kernel_func = [](const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { @@ -134,7 +141,7 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser) }; // Benchmark the kernel - profiler.benchmark(gemm_problem, kernel_func, warp_tile_dims); + profiler.benchmark(gemm_problem, kernel_func, config); // Select best instance based on metric profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); @@ -153,7 +160,7 @@ int main(int argc, char* argv[]) if(!result) return EXIT_FAILURE; - benchmark_gemm_preshuffle_single(parser); + benchmark_single(parser); return 0; } catch(const std::exception& e) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index 4fb98dc3c2..09ec895ab5 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -75,58 +75,6 @@ constexpr auto is_row_major(Layout) return ck_tile::bool_constant>{}; } -// // Permutation function for pk_int4_t -// template -// void permute_vectors_i4x4_b(Tensor& tensor) -// { -// const ck_tile::index_t K = tensor.get_length(0); -// const ck_tile::index_t N = tensor.get_length(1); -// // vector pk_i4x4 permute -// for(int i = 0; i < N; i++) -// { -// for(int j = 0; j < K; j += 8) -// { -// int8_t input[8]; - -// for(int k = 0; k < 4; k++) -// { -// int8_t i4x2 = tensor(j + k * 2, i).data; -// input[k * 2 + 0] = (i4x2 >> 4) & 0xf; -// input[k * 2 + 1] = (i4x2 >> 0) & 0xf; -// } - -// // permute 01234567->20643175 -// { -// int8_t hi = input[2]; -// int8_t lo = input[0]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 0, i) = i4x2; -// } - -// { -// int8_t hi = input[6]; -// int8_t lo = input[4]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 2, i) = i4x2; -// } - -// { -// int8_t hi = input[3]; -// int8_t lo = input[1]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 4, i) = i4x2; -// } - -// { -// int8_t hi = input[7]; -// int8_t lo = input[5]; -// int8_t i4x2 = (hi << 4) | lo; -// tensor(j + 6, i) = i4x2; -// } -// } -// } -// } - // Structure to hold kernel traits for dispatcher struct KernelTraits { @@ -211,3 +159,27 @@ auto shuffle_b(const ck_tile::HostTensor& t, std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, + ck_tile::index_t N_Warp_Tile, + ck_tile::index_t K_Warp_Tile, + ck_tile::index_t N_Tile, + ck_tile::index_t N_Warp) +{ + assert(t.get_lengths().size() == 2); + + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int divisor = N_Warp_Tile == 32 ? 2 : 4; + int NRepeat = N_Tile / N_Warp_Tile / N_Warp; + ck_tile::HostTensor t_view({n_ / N_Tile, + N_Warp, + N_Warp_Tile, + NRepeat, + k_ / K_Warp_Tile, + divisor, + K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index e6e075cb36..1d4b027716 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -95,67 +95,87 @@ class GemmPreshuffleKernelBuilder: def _get_tile_configs(self, fast_mode=False): """Get tile configurations for the current datatype and layout""" - if "tile_configs" in self.config: - # Old format - return ( - self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + + tile_config = self.config["tile_config"] + + # Generate values in the config if default range is given + if tile_config.get("tile_m").get("values") is None: + tile_config.get("tile_m")["values"] = self._generate_values( + tile_config.get("tile_m").get("min"), + tile_config.get("tile_m").get("max"), + tile_config.get("tile_m").get("step"), + ) + if tile_config.get("tile_n").get("values") is None: + tile_config.get("tile_n")["values"] = self._generate_values( + tile_config.get("tile_n").get("min"), + tile_config.get("tile_n").get("max"), + tile_config.get("tile_n").get("step"), + ) + if tile_config.get("tile_k").get("values") is None: + tile_config.get("tile_k")["values"] = self._generate_values( + tile_config.get("tile_k").get("min"), + tile_config.get("tile_k").get("max"), + tile_config.get("tile_k").get("step"), ) - elif "tile_config" in self.config: - # New format - generate combinations from individual parameter values - tile_config = self.config["tile_config"] - # Get all possible values for each parameter - tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) - tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) - tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) - warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) - warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) - warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) - warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) - warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) - warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m").get("values") + tile_n_values = tile_config.get("tile_n").get("values") + tile_k_values = tile_config.get("tile_k").get("values") + warp_m_values = tile_config.get("warp_m").get("values") + warp_n_values = tile_config.get("warp_n").get("values") + warp_k_values = tile_config.get("warp_k").get("values") + warp_tile_m_values = tile_config.get("warp_tile_m").get("values") + warp_tile_n_values = tile_config.get("warp_tile_n").get("values") + warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - # Generate all combinations - configs = [] - for tile_m in tile_m_values: - for tile_n in tile_n_values: - for tile_k in tile_k_values: - for warp_m in warp_m_values: - for warp_n in warp_n_values: - for warp_k in warp_k_values: - for warp_tile_m in warp_tile_m_values: - for warp_tile_n in warp_tile_n_values: - for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - fast_mode=fast_mode, - ): - configs.append( - { - "tile_m": tile_m, - "tile_n": tile_n, - "tile_k": tile_k, - "warp_m": warp_m, - "warp_n": warp_n, - "warp_k": warp_k, - "warp_tile_m": warp_tile_m, - "warp_tile_n": warp_tile_n, - "warp_tile_k": warp_tile_k, - } - ) - return configs - else: - # Fallback to default - return [] + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + + def _generate_values(self, min_val, max_val, step): + """Generate a list of values from min to max with the given step""" + values = [] + val = min_val + while val <= max_val: + values.append(val) + val += step + return values def _generate_trait_combinations(self): """Generate all combinations of traits""" @@ -270,6 +290,12 @@ class GemmPreshuffleKernelBuilder: return True else: + # Validate preshuffle specific constraints + if self.config.get("permute_n"): + valid = (tile_n / warp_tile_n / warp_n) % 2 == 0 + if not valid: + return False + # Full validation for generation # Determine data types for validation a_datatype = self.datatype @@ -299,7 +325,7 @@ class GemmPreshuffleKernelBuilder: ) def _generate_kernel_instance( - self, tile_config, trait_combo, k_block_per_cu, is_header=True + self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True ): """Generate a single kernel instance""" ( @@ -349,9 +375,9 @@ class GemmPreshuffleKernelBuilder: acc_type = "float" # Determine output type - c_type = get_dtype_string(self.datatype) + c_type = self.datatype if self.datatype in ["fp8", "bf8"]: - c_type = "ck_tile::fp16_t" + c_type = "fp16" # Determine layouts based on self.layout a_layout, b_layout, c_layout = get_abc_layouts(self.layout) @@ -374,7 +400,7 @@ class GemmPreshuffleKernelBuilder: using ADataType = {get_dtype_string(self.datatype)}; using BDataType = {get_dtype_string(self.datatype)}; using AccDataType = {acc_type}; -using CDataType = {c_type}; +using CDataType = {get_dtype_string(c_type)}; using ALayout = {a_layout}; using BLayout = {b_layout}; @@ -408,6 +434,8 @@ struct SelectedKernel {{ static constexpr bool Preshuffle = true; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool PermuteN = {"true" if permute_n else "false"}; + // Tile shape using TileShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -485,7 +513,10 @@ struct SelectedKernel {{ WarpTileK, // KPerXdl_ TransposeC, // isCTransposed_ memory_operation, // MemoryOperation_ - NumWaveGroups>; // kNumWaveGroups_ + NumWaveGroups, // kNumWaveGroups_ + false, // FixedVectorSize_ + 1, // VectorSizeC_ + PermuteN>; // isPermuteN_ using GemmEpilogue = ck_tile::CShuffleEpilogue; """ @@ -580,6 +611,7 @@ struct SelectedKernel {{ tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() k_block_per_cu = self.config.get("k_block_per_cu") + permute_n = self.config.get("permute_n") # Prepare work items for parallel processing work_items = [] @@ -590,6 +622,7 @@ struct SelectedKernel {{ tile_config, trait_combo, k_block_per_cu, + permute_n, self.working_path, self.datatype, self.layout, @@ -681,21 +714,29 @@ struct SelectedKernel {{ def _generate_single_kernel_individual(work_item): """Worker function to generate a single individual kernel file""" - tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item + ( + tile_config, + trait_combo, + k_block_per_cu, + permute_n, + working_path, + datatype, + layout, + ) = work_item # Create a temporary builder instance for this worker builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout) try: kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu + tile_config, trait_combo, k_block_per_cu, permute_n ) - # Create simplified filename without the "gemm_" prefix - # Remove "gemm_" from the beginning of kernel_name for the filename + # Create simplified filename without the "gemm_preshuffle_" prefix + # Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename simplified_name = kernel_name - if simplified_name.startswith("gemm_"): - simplified_name = simplified_name[5:] # Remove "gemm_" prefix + if simplified_name.startswith("gemm_preshuffle_"): + simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix # Write individual header file header_file = working_path / f"gemm_single_{simplified_name}.hpp" @@ -727,7 +768,7 @@ def main(): parser.add_argument( "--layout", required=True, - choices=["rcr", "rrr", "ccr", "crr"], + choices=["rcr"], help="Matrix layout", ) parser.add_argument("--config_json", required=True, help="Configuration JSON file") @@ -735,7 +776,9 @@ def main(): "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - "--gen_individual", action="store_true", help="Generate individual kernel files" + "--gen_all_individual", + action="store_true", + help="Generate individual kernel files", ) parser.add_argument( "--gen_single", action="store_true", help="Generate a single kernel file" @@ -763,7 +806,7 @@ def main(): assert len(layout_parts) == 3, ( f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)" ) - assert layout_parts[0] == "r" and layout_parts[1] == "c", ( + assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], ( f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)" ) assert layout_parts[2] == "r", ( @@ -816,10 +859,11 @@ def main(): ) k_block_per_cu = builder.config.get("k_block_per_cu") + permute_n = builder.config.get("permute_n") # Generate the kernel kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo, k_block_per_cu + tile_config, trait_combo, k_block_per_cu, permute_n ) # Write the file @@ -835,13 +879,13 @@ def main(): print(f"Generated {header_file}") - elif args.gen_individual: + elif args.gen_all_individual: # Generate all individual kernel files builder.run(args.num_workers) pass else: parser.error( - "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single" ) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 4f2a929ba0..7d212c934c 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -2,7 +2,7 @@ #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" -#include "benchmark_gemm_preshuffle.hpp" +#include "gemm_preshuffle_benchmark.hpp" class GemmProfiler { @@ -17,7 +17,7 @@ class GemmProfiler void benchmark(GemmProblem& gemm_problem, std::function kernel_func, - const std::tuple& warp_tile_dims) + KernelConfig& config) { // Create a vector with a single callable that returns both name and time std::vector(ck_tile::GemmHostArgs&, @@ -30,13 +30,13 @@ class GemmProfiler return std::make_tuple(std::string(KERNEL_NAME), time); }); - benchmark(gemm_problem, callables, warp_tile_dims); + benchmark(gemm_problem, callables, config); } void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables, - const std::tuple& warp_tile_dims) + KernelConfig& config) { const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; @@ -110,11 +110,22 @@ class GemmProfiler for(const auto& callable : callables) { - ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims); - ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims); + ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); + ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims); + ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); + ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); + + ck_tile::HostTensor b_shuffle_host = [&]() { + if(config.permuteN) + { + return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); + } + else + { + return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + } + }(); - ck_tile::HostTensor b_shuffle_host = - shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); ck_tile::GemmHostArgs gemm_args = {