[CK_TILE]enhance elementwise test (#2683)

* enhance elementwise

* fix ci issues

[ROCm/composable_kernel commit: b60af5bde9]
This commit is contained in:
joyeamd
2025-09-30 23:29:37 +08:00
committed by GitHub
parent 20333fd850
commit e17b20625e
2 changed files with 45 additions and 52 deletions

View File

@@ -104,24 +104,8 @@ struct ElementWiseKernel
template <typename... Ints>
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
{
int total_elements = 1;
const auto kVectorM = Problem_::BlockShape::kVectorM;
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
if((total_elements % kVectorM) != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met: total number of input elements (",
total_elements,
") should be multiple of the vectorization size (",
kVectorM,
")");
}
return false;
}
// when total elements % kVectorM != 0; should use Pad instead of unsupported
ignore = input_sizes;
return true;
}
};

View File

@@ -92,43 +92,52 @@ class TestCkTileElementwise : public ::testing::Test
YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());
auto run_elementwise_kernel = [&](auto has_remainder) {
constexpr bool kPad = decltype(has_remainder)::value;
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType,
kPad>;
using Policy = ck_tile::ElementWiseDefaultPolicy;
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) /
TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
dim3 block = dim3(ew_kernel.BlockSize());
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
};
// Problem and Policy
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType>;
using Policy = ck_tile::ElementWiseDefaultPolicy;
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
// Launch configuration
ck_tile::index_t grid_size =
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
dim3 block = dim3(ew_kernel.BlockSize());
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
// Check if the kernel configuration is supported
if(!ew_kernel.IsSupportedArgument(lens))
using BaseProblem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType>;
if(total_m_elements % BaseProblem::BlockShape::kVectorM)
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
run_elementwise_kernel(std::true_type{});
}
else
{
run_elementwise_kernel(std::false_type{});
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
d_y_mem.FromDevice(h_y.data());