Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 08:15:15 +00:00
parent e571490afc
commit 6f6c855c0e
13 changed files with 860 additions and 99 deletions

View File

@@ -17,6 +17,7 @@ struct PoolHostArgs
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
void* output_ptr_,
void* output_index_ptr_,
TensorShape input_shape_,
TensorShape output_shape_,
TensorShape input_strides_,
@@ -28,6 +29,7 @@ struct PoolHostArgs
WindowShape input_right_pads_)
: input_ptr(input_ptr_),
output_ptr(output_ptr_),
output_index_ptr(output_index_ptr_),
input_shape(input_shape_),
output_shape(output_shape_),
input_strides(input_strides_),
@@ -42,6 +44,7 @@ struct PoolHostArgs
const void* input_ptr;
void* output_ptr;
void* output_index_ptr;
TensorShape input_shape;
TensorShape output_shape;
@@ -60,6 +63,7 @@ struct PoolKernelArgs
{
const void* input_ptr;
void* output_ptr;
void* output_index_ptr;
TensorShape input_shape;
TensorShape output_shape;
TensorShape input_strides;
@@ -80,6 +84,7 @@ struct PoolKernel
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
@@ -205,7 +210,23 @@ struct PoolKernel
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded);
if constexpr(Problem::kOutputIndex)
{
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
static_cast<IndexDataType*>(kargs.output_index_ptr),
out_desc.get_element_space_size(),
IndexDataType(-1));
const auto out_index_tensor_padded =
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
out_index_buffer_view, out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
}
else
{
// Return a dummy tensor for the third element when index output is not needed
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
}
}
template <typename TensorShape, typename WindowShape>
@@ -338,7 +359,23 @@ struct PoolKernel
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded);
if constexpr(Problem::kOutputIndex)
{
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
static_cast<IndexDataType*>(kargs.output_index_ptr),
out_desc.get_element_space_size(),
IndexDataType(-1));
const auto out_index_tensor_padded =
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
out_index_buffer_view, out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
}
else
{
// Return a dummy tensor for the third element when index output is not needed
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
}
}
public:
@@ -354,7 +391,7 @@ struct PoolKernel
const auto iM = get_block_id() * S::Block_M;
// Get tensors based on dimensionality
auto [in_tensor_padded, out_tensor_padded] = [&]() {
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
if constexpr(WindowShape::size() == 2)
return MakeTensorView2D(kargs);
else if constexpr(WindowShape::size() == 3)
@@ -387,16 +424,57 @@ struct PoolKernel
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
if constexpr(Problem::kOutputIndex)
{
const auto x_tile = load_tile(x_window);
block_reduce2d(x_tile, y_tile, reduce_op);
move_tile_window(x_window, {0, S::Block_N});
}
auto y_index_window =
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
block_reduce2d_sync(y_tile, reduce_op);
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
store_tile(y_window, cast_tile<OutDataType>(y_tile));
auto y_index_tile =
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
set_tile(y_index_tile, IndexDataType(0));
// Main reduction loop - with index tracking
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
{
const auto x_tile = load_tile(x_window);
auto index_calculator = [&](const auto& x_indices) {
// Get global coordinates in the 2D matrix space (M, N)
const auto global_M = x_indices.at(number<0>{}) + iM;
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
make_tuple(global_M, global_N));
};
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
move_tile_window(x_window, {0, S::Block_N});
}
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
if constexpr(Problem::kNeedCrossWarpSync)
{
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
}
store_tile(y_window, cast_tile<OutDataType>(y_tile));
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
}
else
{
// Main reduction loop - without index tracking
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
{
const auto x_tile = load_tile(x_window);
block_reduce2d(x_tile, y_tile, reduce_op);
move_tile_window(x_window, {0, S::Block_N});
}
block_reduce2d_sync(y_tile, reduce_op);
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
store_tile(y_window, cast_tile<OutDataType>(y_tile));
}
}
/// @brief Validates if the given arguments are supported by the pooling kernel.
@@ -481,6 +559,7 @@ struct PoolKernel
{
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
host_args.output_ptr,
host_args.output_index_ptr,
host_args.input_shape,
host_args.output_shape,
host_args.input_strides,