[rocm-libraries] ROCm/rocm-libraries#6399 (commit 13bf528)

[CK][CK TILE] Modify elementwise kernel template signature to
 accept independent type arguments (#6399)

## Motivation

modify elementwise kernel template signature to fix cshuffle epilogue
build error

## Technical Details

Encountered a build error while building conv fallback kernel with
dispatcher.
Error: Type mismatch in `ElementWiseKernel::operator()` where the
template required all three parameters (lens, input_strides,
output_strides) to be the same type, but the CShuffle epilogue was
passing them with different tuple element types.

Solution: Modified the template signature in elementwise_kernel.hpp to
accept three independent type parameters:

Changed from single typename `Dims` to typename `DimsLens`, typename
`DimsInStrides`, typename `DimsOutStrides`
Updated references to `Dims::size()` to use the appropriate specific
type

## Test Plan

- Test with dispatcher conv unit tests
- Relying on CI tests

## Test Result
- Dispatcher unit tests passed
- Relying on CI tests

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yaswanth Raparti
2026-04-14 07:45:14 +00:00
committed by assistant-librarian[bot]
parent 918e8a1bd8
commit 9491563725

View File

@@ -27,10 +27,13 @@ struct ElementWiseKernel
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(const Dims lens,
const Dims input_strides,
const Dims output_strides,
template <typename... XDataType,
typename DimsLens,
typename DimsInStrides,
typename DimsOutStrides>
CK_TILE_DEVICE void operator()(const DimsLens lens,
const DimsInStrides input_strides,
const DimsOutStrides output_strides,
const tuple<XDataType...>& input_tensors,
YDataType* p_y) const
{
@@ -49,10 +52,11 @@ struct ElementWiseKernel
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
const auto transformed_tensor = pad_tensor_view(
transform_tensor_view(tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
transform_tensor_view(
tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<DimsLens::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
@@ -86,13 +90,14 @@ struct ElementWiseKernel
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, lens, output_strides, number<S::kVectorM>{});
const auto transformed_y_m_n = pad_tensor_view(
transform_tensor_view(y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
const auto transformed_y_m_n =
pad_tensor_view(transform_tensor_view(
y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<DimsOutStrides::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
auto y_window = make_tile_window(transformed_y_m_n,
make_tuple(number<S::kBlockM>{}),