mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[ck_tile] refactor reduce kernel (#3257)
* refactor reduce kernel - Rename Reduce kernel as per convention - Move kept_dim and reduce_dims from runtime to compile-time parameters - Update Reduce2dProblem template to include KeptDim, ReduceDims, and Rank - Remove IsSupportedArgument validation function as it's unnecessary. Not using the GuaranteedLastDimensionVectorStride while making tensor view or descriptor which removes the bounds enforced earlier. We still calculate and use vector size. - Update reduce example to demonstrate NCHW->NHW reduction with non-contiguous support - Update tests Kernel now handles both contiguous and non-contiguous memory layout. * fix compile errors
This commit is contained in:
committed by
GitHub
parent
92653168c2
commit
ea10a78203
@@ -53,10 +53,16 @@ class TestCkTileReduce : public ::testing::Test
|
||||
d_y_mem.ToDevice(h_y.data()); // Initialize device output buffer
|
||||
|
||||
// Problem and kernel setup
|
||||
using Problem = ck_tile::
|
||||
Reduce2dProblem<XDataType, ComputeDataType, YDataType, TestReduce2dShape, ReduceOpType>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestReduce2dShape,
|
||||
ReduceOpType,
|
||||
KeptDimSeq,
|
||||
ReduceDimSeq,
|
||||
InputDim>;
|
||||
|
||||
using Kernel = ck_tile::Reduce<Problem>;
|
||||
using Kernel = ck_tile::ReduceKernel<Problem>;
|
||||
|
||||
// Launch configuration
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
@@ -75,13 +81,6 @@ class TestCkTileReduce : public ::testing::Test
|
||||
auto input_shape_tuple = make_shape_tuple.template operator()<InputDim>(input_shape);
|
||||
auto input_strides_tuple = make_shape_tuple.template operator()<InputDim>(input_strides);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
output_shape[output_shape.size() - 1],
|
||||
input_strides_tuple)) // output tensor's continuous dimension
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
@@ -91,9 +90,7 @@ class TestCkTileReduce : public ::testing::Test
|
||||
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
|
||||
input_shape_tuple,
|
||||
input_strides_tuple,
|
||||
kept_dims,
|
||||
reduce_dims));
|
||||
input_strides_tuple));
|
||||
|
||||
// Get results back
|
||||
d_y_mem.FromDevice(h_y.data());
|
||||
|
||||
Reference in New Issue
Block a user