From d988d552756f7dbffba3ba671d32eeb61299617c Mon Sep 17 00:00:00 2001 From: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com> Date: Tue, 14 Apr 2026 00:44:27 -0700 Subject: [PATCH] [CK][CK TILE] Modify elementwise kernel template signature to accept independent type arguments (#6399) ## Motivation Modify the elementwise kernel template signature to fix a CShuffle epilogue build error. ## Technical Details Encountered a build error while building the conv fallback kernel with the dispatcher. Error: Type mismatch in `ElementWiseKernel::operator()`, where the template required all three parameters (lens, input_strides, output_strides) to be the same type, but the CShuffle epilogue was passing them with different tuple element types. Solution: Modified the template signature in elementwise_kernel.hpp to accept three independent type parameters. Changed from a single typename `Dims` to typename `DimsLens`, typename `DimsInStrides`, typename `DimsOutStrides`. Updated references to `Dims::size()` to use the appropriate specific type. ## Test Plan - Test with dispatcher conv unit tests - Relying on CI tests ## Test Result - Dispatcher unit tests passed - Relying on CI tests ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../elementwise/kernel/elementwise_kernel.hpp | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index a4dd791b83..d9d3897101 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -27,10 +27,13 @@ struct ElementWiseKernel return is_wave32() ? 
kBlockSize / 2 : kBlockSize; } - template - CK_TILE_DEVICE void operator()(const Dims lens, - const Dims input_strides, - const Dims output_strides, + template + CK_TILE_DEVICE void operator()(const DimsLens lens, + const DimsInStrides input_strides, + const DimsOutStrides output_strides, const tuple& input_tensors, YDataType* p_y) const { @@ -49,10 +52,11 @@ struct ElementWiseKernel input_tensors.get(i), lens, input_strides, number{}, number<1>{}); const auto transformed_tensor = pad_tensor_view( - transform_tensor_view(tensor_view, - ck_tile::make_tuple(merge_transform), - ck_tile::make_tuple(make_index_sequence{}), - ck_tile::make_tuple(sequence<0>{})), + transform_tensor_view( + tensor_view, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), ck_tile::make_tuple(number{}), sequence{}); @@ -86,13 +90,14 @@ struct ElementWiseKernel const auto y_m_n = make_naive_tensor_view( p_y, lens, output_strides, number{}); - const auto transformed_y_m_n = pad_tensor_view( - transform_tensor_view(y_m_n, - ck_tile::make_tuple(merge_transform), - ck_tile::make_tuple(make_index_sequence{}), - ck_tile::make_tuple(sequence<0>{})), - ck_tile::make_tuple(number{}), - sequence{}); + const auto transformed_y_m_n = + pad_tensor_view(transform_tensor_view( + y_m_n, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); auto y_window = make_tile_window(transformed_y_m_n, make_tuple(number{}),