From 799b096f50b0c81408a69153e995d011ea4a6043 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 24 Feb 2026 22:40:48 +0100 Subject: [PATCH] [CK][CK TILE] Improve oob check (#4791) ## Motivation Improve OOB checks. Remove permutes which have been generated by thread buffer zero clear. at now in assembly there is only condmask instead of permute + condmask. Change number of KPack for generated instances ## Technical Details Remove permute instructions from assembly ## Test Plan test_grouped_convnd_fwd_tile ## Test Result passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: jakpiase --- .../generate_instances.py | 16 ++- .../core/arch/amd_buffer_addressing.hpp | 95 +++++++++++++++--- .../arch/amd_buffer_addressing_builtins.hpp | 98 +++++++++++++++---- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 6 ++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 1 + .../src/profile_grouped_conv_fwd_tile.cpp | 2 +- 6 files changed, 187 insertions(+), 31 deletions(-) diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 3216884a70..37c2db8a7b 100644 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -197,7 +197,21 @@ def parse_fwd_instances(instances, problem_name): dtype = get_dtype(problem_name) # TODO: Make it more flexible # k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()" - k_per_xdl = 8 if dtype == "float" else 16 + if dtype == "float": + if m_per_xdl == 32: + if instance.find("BlkGemmPipelineVersion") == -1: + k_per_xdl = 4 + else: + # Increase for universal gemm + k_per_xdl = 8 + else: + k_per_xdl = 8 + else: + if m_per_xdl == 32: + k_per_xdl = 16 + else: + k_per_xdl = 32 + k_per_xdl = min(k_per_xdl, k_per_block) conv = ConvInstanceTemplateParams( spec, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index f1aba16645..246b2b85a7 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2578,6 +2578,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer src_thr } } +template +using has_type = typename T::type; + // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. @@ -2608,12 +2611,48 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; - else - return tmp; + { + if(!src_thread_element_valid) + { + if constexpr(is_detected::value) + { + // Use vector_t for not valid elements to avoid permute instructions. + // Get raw type from structure + using vector_t = typename T::type __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N) + { + // Not possible to use set_as + return thread_buffer{numeric::zero()}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, + vector_t{numeric::zero()}); + return tmp; + } + } + else + { + // Use vector_t for not valid elements to avoid permute instructions. + using vector_t = T __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(T) * N) + { + // Not possible to use set_as + return thread_buffer{numeric::zero()}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{numeric::zero()}); + return tmp; + } + } + } + } + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); #endif } @@ -2637,13 +2676,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{customized_value}; - else - return tmp; + { + if(!src_thread_element_valid) + { + if constexpr(is_detected::value) + { + // Use vector_t for not valid elements to avoid permute instructions. + // Get raw type from structure + using vector_t = typename T::type __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N) + { + // Not possible to use set_as + return thread_buffer{customized_value}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{customized_value}); + return tmp; + } + } + else + { + // Use vector_t for not valid elements to avoid permute instructions. + using vector_t = T __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(T) * N) + { + // Not possible to use set_as + return thread_buffer{customized_value}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{customized_value}); + return tmp; + } + } + } + } + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); } template src_thr } } +template +using has_type = typename T::type; + // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. @@ -2436,21 +2439,46 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #else if constexpr(oob_conditional_check) { - if(src_thread_element_valid) + if(!src_thread_element_valid) { - return amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - } - else - { - return thread_buffer{numeric::zero()}; + if constexpr(is_detected::value) + { + // Use vector_t for not valid elements to avoid permute instructions. + // Get raw type from structure + using vector_t = typename T::type __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N) + { + // Not possible to use set_as + return thread_buffer{numeric::zero()}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, + vector_t{numeric::zero()}); + return tmp; + } + } + else + { + // Use vector_t for not valid elements to avoid permute instructions. + using vector_t = T __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(T) * N) + { + // Not possible to use set_as + return thread_buffer{numeric::zero()}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{numeric::zero()}); + return tmp; + } + } } } - else - { - return amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - } + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); #endif } @@ -2474,13 +2502,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{customized_value}; - else - return tmp; + { + if(!src_thread_element_valid) + { + if constexpr(is_detected::value) + { + // Use vector_t for not valid elements to avoid permute instructions. + // Get raw type from structure + using vector_t = typename T::type __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N) + { + // Not possible to use set_as + return thread_buffer{customized_value}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{customized_value}); + return tmp; + } + } + else + { + // Use vector_t for not valid elements to avoid permute instructions. + using vector_t = T __attribute__((ext_vector_type(N))); + if constexpr(sizeof(vector_t) != sizeof(T) * N) + { + // Not possible to use set_as + return thread_buffer{customized_value}; + } + else + { + thread_buffer tmp; + tmp.template set_as(number<0>{}, vector_t{customized_value}); + return tmp; + } + } + } + } + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); } template >; +template +using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl, + 2, + AttrNumAccess>>; + template using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution = WarpGemmImpl struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K4; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; // fp16 diff --git a/profiler/src/profile_grouped_conv_fwd_tile.cpp b/profiler/src/profile_grouped_conv_fwd_tile.cpp index 1a1e8b769a..2c436abb8f 100644 --- a/profiler/src/profile_grouped_conv_fwd_tile.cpp +++ b/profiler/src/profile_grouped_conv_fwd_tile.cpp @@ -96,7 +96,7 @@ int call_profiler(const ckt::Args& args, bool time_kernel) std::string op_name; bool valid; std::tie(valid, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs( - args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel}); + args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel, 0, 5, 50}); if(time_kernel) { std::cout << "Best configuration parameters:" << "\nname: " << op_name