diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp index 2d7ac78b06..a698c91e45 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -35,8 +35,6 @@ struct Add return type_convert(y_ + x_); } - - static constexpr bool requires_special_combine = false; }; struct SquareAdd @@ -64,28 +62,6 @@ struct SquareAdd float x_ = type_convert(x); return type_convert(y_ + (x_ * x_)); } - - // For combining partial results - template || std::is_same_v || - std::is_same_v || std::is_same_v>> - CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1, - const T& partial2) const - { - return partial1 + partial2; // Just add the partial sums, don't square again - } - - template || std::is_same_v || - std::is_same_v || std::is_same_v>> - CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const - { - float partial1_ = type_convert(partial1); - float partial2_ = type_convert(partial2); - return type_convert(partial1_ + partial2_); - } - - static constexpr bool requires_special_combine = true; }; struct Max @@ -109,8 +85,6 @@ struct Max { return max(y, x); } - - static constexpr bool requires_special_combine = false; }; struct AbsMax @@ -134,8 +108,6 @@ struct AbsMax { return max(y, abs(x)); } - - static constexpr bool requires_special_combine = false; }; } // namespace ReduceOp diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 849fa6c252..b72657b785 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -183,16 +183,7 @@ struct BlockReduce2dSync // pull data from remote lane const auto v_remote = warp_shuffle(v_local, src_lane); - - // For reduce, use combine_partial_results for operations that require it - if constexpr(ReduceFunc::requires_special_combine) - { - v_local = reduce_func.combine_partial_results(v_local, v_remote); - } - else - { - v_local = reduce_func(v_local, v_remote); - } + v_local = reduce_func(v_local, v_remote); }); } }); @@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - - // For reduce, use combine_partial_results for operations that require it - if constexpr(ReduceFunc::requires_special_combine) - { - v_local = reduce_func.combine_partial_results(v_local, v_remote); - } - else - { - v_local = reduce_func(v_local, v_remote); - } + v_local = reduce_func(v_local, v_remote); }); y_tensor.get_thread_buffer()(i_0) = v_local; diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index f65487ea6e..0cae4023b7 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -189,7 +189,9 @@ struct Reduce /// @note Requirements: /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution) /// - input_strides[-1] == 1 (for contiguous memory access) - CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides) + template + CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, + InputStrides input_strides) { using S = typename Problem::BlockShape; diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index 4ce0b56ef3..821d0a6c3e 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -308,20 +308,8 @@ using TestConfig_F32_Max = std::tuple; -using TestConfig_F32_SquareAdd = std::tuple; - -using TestTypes = ::testing::Types; +using TestTypes = ::testing:: + Types; TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);