diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index a31edbdd2b..b1e5e01777 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -104,24 +104,8 @@ struct ElementWiseKernel template CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple& input_sizes) { - int total_elements = 1; - const auto kVectorM = Problem_::BlockShape::kVectorM; - - apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes); - - if((total_elements % kVectorM) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Conditions not met: total number of input elements (", - total_elements, - ") should be multiple of the vectorization size (", - kVectorM, - ")"); - } - return false; - } - + // when total elements % kVectorM != 0; should use Pad instead of unsupported + ignore = input_sizes; return true; } }; diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 2eb2b506e8..7daba611e9 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -92,43 +92,52 @@ class TestCkTileElementwise : public ::testing::Test YDataType* p_y_device = static_cast(d_y_mem.GetDeviceBuffer()); + auto run_elementwise_kernel = [&](auto has_remainder) { + constexpr bool kPad = decltype(has_remainder)::value; + using Problem = ck_tile::ElementWisePipelineProblem; + using Policy = ck_tile::ElementWiseDefaultPolicy; + ck_tile::ElementWiseKernel ew_kernel; + + ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) / + TestElementWiseShape::kBlockM; + dim3 grid(grid_size, 1, 1); + dim3 block = dim3(ew_kernel.BlockSize()); + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log + + ck_tile::launch_kernel(s, + ck_tile::make_kernel // MinBlockPerCu + (ew_kernel, + grid, + block, + 0, // actual shared memory + lens, + strides, // input strides + strides, // output strides + d_x_ptrs_tuple, + p_y_device)); + }; + // Problem and Policy - using Problem = ck_tile::ElementWisePipelineProblem; - using Policy = ck_tile::ElementWiseDefaultPolicy; - - ck_tile::ElementWiseKernel ew_kernel; - - // Launch configuration - ck_tile::index_t grid_size = - (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM; - dim3 grid(grid_size, 1, 1); - dim3 block = dim3(ew_kernel.BlockSize()); - constexpr ck_tile::index_t kBlockPerCu = 1; - - ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log - - // Check if the kernel configuration is supported - if(!ew_kernel.IsSupportedArgument(lens)) + using BaseProblem = ck_tile::ElementWisePipelineProblem; + if(total_m_elements % BaseProblem::BlockShape::kVectorM) { - throw std::runtime_error( - "The kernel configuration is not supported for the given input size."); + run_elementwise_kernel(std::true_type{}); + } + else + { + run_elementwise_kernel(std::false_type{}); } - - ck_tile::launch_kernel(s, - ck_tile::make_kernel // MinBlockPerCu - (ew_kernel, - grid, - block, - 0, // actual shared memory - lens, - strides, // input strides - strides, // output strides - d_x_ptrs_tuple, - p_y_device)); d_y_mem.FromDevice(h_y.data());