[CK_TILE] Enhance elementwise test (#2683)

* enhance elementwise

* fix CI issues
This commit is contained in:
joyeamd
2025-09-30 23:29:37 +08:00
committed by GitHub
parent e78a897ec0
commit b60af5bde9
2 changed files with 45 additions and 52 deletions

View File

@@ -92,43 +92,52 @@ class TestCkTileElementwise : public ::testing::Test
YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());
// Builds and launches the elementwise kernel for one compile-time padding mode.
//
// `has_remainder` is a tag object whose type exposes a static boolean `value`
// (e.g. std::true_type / std::false_type). Only its type is used: `value`
// becomes the `kPad` template argument of the pipeline problem, so each tag
// type instantiates a distinct kernel.
// NOTE(review): presumably kPad enables padded/tail handling inside the kernel
// when the element count is not a multiple of the vector width -- confirm
// against ElementWisePipelineProblem.
//
// Everything else (lens, strides, d_x_ptrs_tuple, p_y_device,
// total_m_elements) is captured by reference from the enclosing test body.
auto run_elementwise_kernel = [&](auto has_remainder) {
// Lift the tag's boolean into a constant expression usable as a template argument.
constexpr bool kPad = decltype(has_remainder)::value;
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType,
kPad>;
using Policy = ck_tile::ElementWiseDefaultPolicy;
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
// Ceil-divide the element count by kBlockM so a partial trailing tile still
// gets its own grid entry.
ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) /
TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
// Block dimensions are dictated by the kernel itself.
dim3 block = dim3(ew_kernel.BlockSize());
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
// Arguments after the shared-memory size are forwarded to the kernel. The same
// `strides` object is passed for both the input and the output layout.
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
};
// Problem and Policy
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType>;
using Policy = ck_tile::ElementWiseDefaultPolicy;
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
// Launch configuration
ck_tile::index_t grid_size =
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
dim3 block = dim3(ew_kernel.BlockSize());
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
// Check if the kernel configuration is supported
if(!ew_kernel.IsSupportedArgument(lens))
using BaseProblem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType>;
if(total_m_elements % BaseProblem::BlockShape::kVectorM)
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
run_elementwise_kernel(std::true_type{});
}
else
{
run_elementwise_kernel(std::false_type{});
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
d_y_mem.FromDevice(h_y.data());