mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
* fix reduce2d - revert the combine_partial_results() changes - remove auto from function def * clang-format
This commit is contained in:
committed by
GitHub
parent
1e1ee758fa
commit
191c62967b
@@ -35,8 +35,6 @@ struct Add
|
||||
|
||||
return type_convert<T>(y_ + x_);
|
||||
}
|
||||
|
||||
static constexpr bool requires_special_combine = false;
|
||||
};
|
||||
|
||||
struct SquareAdd
|
||||
@@ -64,28 +62,6 @@ struct SquareAdd
|
||||
float x_ = type_convert<float>(x);
|
||||
return type_convert<T>(y_ + (x_ * x_));
|
||||
}
|
||||
|
||||
// For combining partial results
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
|
||||
const T& partial2) const
|
||||
{
|
||||
return partial1 + partial2; // Just add the partial sums, don't square again
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const
|
||||
{
|
||||
float partial1_ = type_convert<float>(partial1);
|
||||
float partial2_ = type_convert<float>(partial2);
|
||||
return type_convert<T>(partial1_ + partial2_);
|
||||
}
|
||||
|
||||
static constexpr bool requires_special_combine = true;
|
||||
};
|
||||
|
||||
struct Max
|
||||
@@ -109,8 +85,6 @@ struct Max
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
static constexpr bool requires_special_combine = false;
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
@@ -134,8 +108,6 @@ struct AbsMax
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
static constexpr bool requires_special_combine = false;
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
|
||||
@@ -183,16 +183,7 @@ struct BlockReduce2dSync
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// For reduce, use combine_partial_results for operations that require it
|
||||
if constexpr(ReduceFunc::requires_special_combine)
|
||||
{
|
||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
|
||||
@@ -189,7 +189,9 @@ struct Reduce
|
||||
/// @note Requirements:
|
||||
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
|
||||
/// - input_strides[-1] == 1 (for contiguous memory access)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
|
||||
template <typename InputStrides>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
|
||||
InputStrides input_strides)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
|
||||
@@ -308,20 +308,8 @@ using TestConfig_F32_Max = std::tuple<float,
|
||||
Shape1_WarpTile,
|
||||
Shape1_ThreadTile>;
|
||||
|
||||
using TestConfig_F32_SquareAdd = std::tuple<float,
|
||||
float,
|
||||
float,
|
||||
ck_tile::ReduceOp::SquareAdd,
|
||||
Shape1_BlockWarps,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile,
|
||||
Shape1_ThreadTile>;
|
||||
|
||||
using TestTypes = ::testing::Types<TestConfig_F32_Add,
|
||||
TestConfig_F16_Add,
|
||||
TestConfig_F32_CrossWarp,
|
||||
TestConfig_F32_Max,
|
||||
TestConfig_F32_SquareAdd>;
|
||||
using TestTypes = ::testing::
|
||||
Types<TestConfig_F32_Add, TestConfig_F16_Add, TestConfig_F32_CrossWarp, TestConfig_F32_Max>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user