mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE]enhance elementwise test (#2683)
* enhance elementwise * fix ci issues
This commit is contained in:
@@ -92,43 +92,52 @@ class TestCkTileElementwise : public ::testing::Test
|
||||
|
||||
YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());
|
||||
|
||||
auto run_elementwise_kernel = [&](auto has_remainder) {
|
||||
constexpr bool kPad = decltype(has_remainder)::value;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestElementWiseShape,
|
||||
ElementwiseOpType,
|
||||
kPad>;
|
||||
using Policy = ck_tile::ElementWiseDefaultPolicy;
|
||||
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
|
||||
|
||||
ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) /
|
||||
TestElementWiseShape::kBlockM;
|
||||
dim3 grid(grid_size, 1, 1);
|
||||
dim3 block = dim3(ew_kernel.BlockSize());
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
|
||||
(ew_kernel,
|
||||
grid,
|
||||
block,
|
||||
0, // actual shared memory
|
||||
lens,
|
||||
strides, // input strides
|
||||
strides, // output strides
|
||||
d_x_ptrs_tuple,
|
||||
p_y_device));
|
||||
};
|
||||
|
||||
// Problem and Policy
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestElementWiseShape,
|
||||
ElementwiseOpType>;
|
||||
using Policy = ck_tile::ElementWiseDefaultPolicy;
|
||||
|
||||
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
|
||||
|
||||
// Launch configuration
|
||||
ck_tile::index_t grid_size =
|
||||
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
|
||||
dim3 grid(grid_size, 1, 1);
|
||||
dim3 block = dim3(ew_kernel.BlockSize());
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ew_kernel.IsSupportedArgument(lens))
|
||||
using BaseProblem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestElementWiseShape,
|
||||
ElementwiseOpType>;
|
||||
if(total_m_elements % BaseProblem::BlockShape::kVectorM)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The kernel configuration is not supported for the given input size.");
|
||||
run_elementwise_kernel(std::true_type{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_elementwise_kernel(std::false_type{});
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
|
||||
(ew_kernel,
|
||||
grid,
|
||||
block,
|
||||
0, // actual shared memory
|
||||
lens,
|
||||
strides, // input strides
|
||||
strides, // output strides
|
||||
d_x_ptrs_tuple,
|
||||
p_y_device));
|
||||
|
||||
d_y_mem.FromDevice(h_y.data());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user