diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index a4dd791b83..d9d3897101 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -27,10 +27,13 @@ struct ElementWiseKernel return is_wave32() ? kBlockSize / 2 : kBlockSize; } - template - CK_TILE_DEVICE void operator()(const Dims lens, - const Dims input_strides, - const Dims output_strides, + template + CK_TILE_DEVICE void operator()(const DimsLens lens, + const DimsInStrides input_strides, + const DimsOutStrides output_strides, const tuple& input_tensors, YDataType* p_y) const { @@ -49,10 +52,11 @@ struct ElementWiseKernel input_tensors.get(i), lens, input_strides, number{}, number<1>{}); const auto transformed_tensor = pad_tensor_view( - transform_tensor_view(tensor_view, - ck_tile::make_tuple(merge_transform), - ck_tile::make_tuple(make_index_sequence{}), - ck_tile::make_tuple(sequence<0>{})), + transform_tensor_view( + tensor_view, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), ck_tile::make_tuple(number{}), sequence{}); @@ -86,13 +90,14 @@ struct ElementWiseKernel const auto y_m_n = make_naive_tensor_view( p_y, lens, output_strides, number{}); - const auto transformed_y_m_n = pad_tensor_view( - transform_tensor_view(y_m_n, - ck_tile::make_tuple(merge_transform), - ck_tile::make_tuple(make_index_sequence{}), - ck_tile::make_tuple(sequence<0>{})), - ck_tile::make_tuple(number{}), - sequence{}); + const auto transformed_y_m_n = + pad_tensor_view(transform_tensor_view( + y_m_n, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); auto y_window = make_tile_window(transformed_y_m_n, make_tuple(number{}),