mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
[ck_tile] refactor reduce kernel (#3257)
* refactor reduce kernel
- Rename Reduce kernel as per convention
- Move kept_dim and reduce_dims from runtime to compile-time parameters
- Update Reduce2dProblem template to include KeptDim, ReduceDims, and
Rank
- Remove the IsSupportedArgument validation function, as it is unnecessary.
GuaranteedLastDimensionVectorStride is no longer used when making the tensor
view or descriptor, which removes the bounds enforced earlier. We still
calculate and use the vector size.
- Update reduce example to demonstrate NCHW->NHW reduction with
non-contiguous support
- Update tests
The kernel now handles both contiguous and non-contiguous memory layouts.
* fix compile errors
[ROCm/composable_kernel commit: ea10a78203]
This commit is contained in:
committed by
GitHub
parent
fca34268d1
commit
d73a2287f3
@@ -53,10 +53,16 @@ class TestCkTileReduce : public ::testing::Test
|
||||
d_y_mem.ToDevice(h_y.data()); // Initialize device output buffer
|
||||
|
||||
// Problem and kernel setup
|
||||
using Problem = ck_tile::
|
||||
Reduce2dProblem<XDataType, ComputeDataType, YDataType, TestReduce2dShape, ReduceOpType>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestReduce2dShape,
|
||||
ReduceOpType,
|
||||
KeptDimSeq,
|
||||
ReduceDimSeq,
|
||||
InputDim>;
|
||||
|
||||
using Kernel = ck_tile::Reduce<Problem>;
|
||||
using Kernel = ck_tile::ReduceKernel<Problem>;
|
||||
|
||||
// Launch configuration
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
@@ -75,13 +81,6 @@ class TestCkTileReduce : public ::testing::Test
|
||||
auto input_shape_tuple = make_shape_tuple.template operator()<InputDim>(input_shape);
|
||||
auto input_strides_tuple = make_shape_tuple.template operator()<InputDim>(input_strides);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
output_shape[output_shape.size() - 1],
|
||||
input_strides_tuple)) // output tensor's continuous dimension
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
@@ -91,9 +90,7 @@ class TestCkTileReduce : public ::testing::Test
|
||||
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
|
||||
input_shape_tuple,
|
||||
input_strides_tuple,
|
||||
kept_dims,
|
||||
reduce_dims));
|
||||
input_strides_tuple));
|
||||
|
||||
// Get results back
|
||||
d_y_mem.FromDevice(h_y.data());
|
||||
|
||||
Reference in New Issue
Block a user