[ck_tile] refactor reduce kernel (#3257)

* refactor reduce kernel

- Rename the Reduce kernel to ReduceKernel, as per naming convention

- Move kept_dim and reduce_dims from runtime to compile-time parameters

- Update Reduce2dProblem template to include KeptDim, ReduceDims, and
Rank

- Remove the IsSupportedArgument validation function, as it is no longer
necessary: GuaranteedLastDimensionVectorStride is no longer used when
making the tensor view or descriptor, which removes the bounds enforced
earlier. The vector size is still calculated and used.

- Update reduce example to demonstrate NCHW->NHW reduction with
non-contiguous support

- Update tests

The kernel now handles both contiguous and non-contiguous memory layouts.

* fix compile errors
This commit is contained in:
Yashvardhan Agarwal
2025-12-17 21:46:08 +02:00
committed by GitHub
parent 92653168c2
commit ea10a78203
5 changed files with 89 additions and 130 deletions

View File

@@ -53,10 +53,16 @@ class TestCkTileReduce : public ::testing::Test
d_y_mem.ToDevice(h_y.data()); // Initialize device output buffer
// Problem and kernel setup
using Problem = ck_tile::
Reduce2dProblem<XDataType, ComputeDataType, YDataType, TestReduce2dShape, ReduceOpType>;
using Problem = ck_tile::Reduce2dProblem<XDataType,
ComputeDataType,
YDataType,
TestReduce2dShape,
ReduceOpType,
KeptDimSeq,
ReduceDimSeq,
InputDim>;
using Kernel = ck_tile::Reduce<Problem>;
using Kernel = ck_tile::ReduceKernel<Problem>;
// Launch configuration
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
@@ -75,13 +81,6 @@ class TestCkTileReduce : public ::testing::Test
auto input_shape_tuple = make_shape_tuple.template operator()<InputDim>(input_shape);
auto input_strides_tuple = make_shape_tuple.template operator()<InputDim>(input_strides);
if(!Kernel::IsSupportedArgument(
output_shape[output_shape.size() - 1],
input_strides_tuple)) // output tensor's continuous dimension
{
throw std::runtime_error("Wrong! Arguments not supported!\n");
}
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
@@ -91,9 +90,7 @@ class TestCkTileReduce : public ::testing::Test
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
input_shape_tuple,
input_strides_tuple,
kept_dims,
reduce_dims));
input_strides_tuple));
// Get results back
d_y_mem.FromDevice(h_y.data());