mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-23 22:34:36 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -17,6 +17,7 @@ struct PoolHostArgs
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
void* output_index_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
@@ -28,6 +29,7 @@ struct PoolHostArgs
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
output_index_ptr(output_index_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
@@ -42,6 +44,7 @@ struct PoolHostArgs
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
@@ -60,6 +63,7 @@ struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
@@ -80,6 +84,7 @@ struct PoolKernel
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -205,7 +210,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
@@ -338,7 +359,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -354,7 +391,7 @@ struct PoolKernel
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
@@ -387,16 +424,57 @@ struct PoolKernel
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
auto y_index_window =
|
||||
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
auto y_index_tile =
|
||||
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
|
||||
set_tile(y_index_tile, IndexDataType(0));
|
||||
|
||||
// Main reduction loop - with index tracking
|
||||
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
auto index_calculator = [&](const auto& x_indices) {
|
||||
// Get global coordinates in the 2D matrix space (M, N)
|
||||
const auto global_M = x_indices.at(number<0>{}) + iM;
|
||||
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
|
||||
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(global_M, global_N));
|
||||
};
|
||||
|
||||
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
|
||||
|
||||
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
|
||||
}
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Main reduction loop - without index tracking
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
@@ -481,6 +559,7 @@ struct PoolKernel
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.output_index_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
|
||||
Reference in New Issue
Block a user