mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
[CK_TILE] Add indexing to pooling operator (Lwpck 3892) (#3013)
* Add indexing support to pooling operator
- Add IndexDataType template parameter to pooling problem and kernel
definitions
- Enable pooling kernel to output indices of selected elements during
max/absmax pooling
- Add overloaded operators for Max and AbsMax that track when values
change using bool changed parameter
- Support optional index buffer allocation and management in device
memory
- Modify BlockReduce2d classes to handle index tensors alongside value
tensors
- Add separate shared memory allocation for index data in cross-warp
reductions
- Create validate_pool_indices function to verify index correctness
- Modify pool3d.cpp example to demonstrate index output functionality
- Add tests for index output
* fixes
* Refactor BlockReduce2D functions to get rid auxiliary private types.
* comment resolutions and some changes to block_reduce2d
- index reference implementation improved
- reduce_operator.hpp cleanedup
- updated the block_reduce2d.hpp to have index calculation for
BlockReduce2dLinearCrossWarpSync as well
* conditionally used variable declaration improvement
- the conditionally used vairbales are used only when indexing is
enabled. To inform the compiler that they may be unused and declare them
with least size possible. This may allow it to be optimized compared to
the previous declarations
* comment resolutions
* lexical ordering of the indicies
- introduced accumulate methods that handle the intermediate steps if
needed to order the indexes
* add reduce_operator_accumulate.hpp to core.hpp
---------
Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
[ROCm/composable_kernel commit: 3052d7c9e6]
This commit is contained in:
committed by
GitHub
parent
9ad15a658c
commit
edea16ce14
@@ -17,6 +17,7 @@ struct PoolHostArgs
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
void* output_index_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
@@ -28,6 +29,7 @@ struct PoolHostArgs
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
output_index_ptr(output_index_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
@@ -42,6 +44,7 @@ struct PoolHostArgs
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
@@ -60,6 +63,7 @@ struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
@@ -80,6 +84,7 @@ struct PoolKernel
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -205,7 +210,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
@@ -338,7 +359,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -354,7 +391,7 @@ struct PoolKernel
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
@@ -387,16 +424,57 @@ struct PoolKernel
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
auto y_index_window =
|
||||
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
auto y_index_tile =
|
||||
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
|
||||
set_tile(y_index_tile, IndexDataType(0));
|
||||
|
||||
// Main reduction loop - with index tracking
|
||||
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
auto index_calculator = [&](const auto& x_indices) {
|
||||
// Get global coordinates in the 2D matrix space (M, N)
|
||||
const auto global_M = x_indices.at(number<0>{}) + iM;
|
||||
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
|
||||
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(global_M, global_N));
|
||||
};
|
||||
|
||||
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
|
||||
|
||||
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
|
||||
}
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Main reduction loop - without index tracking
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
@@ -481,6 +559,7 @@ struct PoolKernel
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.output_index_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
|
||||
Reference in New Issue
Block a user