Merge commit '191c62967bf05f58641725b88f038bea462fe651' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-11 13:24:17 +00:00
parent 63c6b9b93c
commit b2ec5dde0a
4 changed files with 7 additions and 63 deletions

View File

@@ -183,16 +183,7 @@ struct BlockReduce2dSync
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
v_local = reduce_func(v_local, v_remote);
});
}
});
@@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
v_local = reduce_func(v_local, v_remote);
});
y_tensor.get_thread_buffer()(i_0) = v_local;

View File

@@ -189,7 +189,9 @@ struct Reduce
/// @note Requirements:
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
/// - input_strides[-1] == 1 (for contiguous memory access)
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
template <typename InputStrides>
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
InputStrides input_strides)
{
using S = typename Problem::BlockShape;