From fbefd916f0dc768996cfbfb5b7895ed92141e54b Mon Sep 17 00:00:00 2001 From: joyeamd Date: Thu, 14 Aug 2025 06:21:46 +0800 Subject: [PATCH] [CK_TILE]fix elementwise example in gfx11/12 (#2676) * fix elementwise examples * improve the robust * fix ck_tile's elementwise test * update elementwise test [ROCm/composable_kernel commit: bcc38deff776b2bca6e228343046782dc85686c3] --- example/ck_tile/21_elementwise/elementwise_example.cpp | 2 +- .../ck_tile/21_elementwise/elementwise_example_add_4d.cpp | 2 +- .../21_elementwise/elementwise_example_transpose.cpp | 5 +++-- .../ck_tile/21_elementwise/elementwise_example_unary.cpp | 3 +-- .../ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp | 7 ++++--- test/ck_tile/elementwise/test_elementwise_1d.cpp | 5 ++--- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 4c501860fd..469345b46c 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // ElementWiseShape bundles these tiling parameters. // It calculates derived properties like threads per wavefront, repeats, vectorization and total // block size. - using Shape = ck_tile::ElementWiseShape; + using Shape = ck_tile::ElementWiseShape; // ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel: // - Data types (input, compute, output). diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index f18a910813..4a031265c9 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using BlockWarps = ck_tile::sequence<1>; using WarpTile = ck_tile::sequence<256>; - using Shape = ck_tile::ElementWiseShape; + using Shape = ck_tile::ElementWiseShape; using Problem = ck_tile::ElementWisePipelineProblem; using WarpTile = ck_tile::sequence<64>; - using Shape = ck_tile::ElementWiseShape; + using Shape = ck_tile::ElementWiseShape; // Problem definition for a single input tensor using Problem = ck_tile::ElementWisePipelineProblem{}); + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index 147dfd3424..d83592a033 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -38,7 +38,6 @@ bool run(const ck_tile::ArgParser& arg_parser) using XDataType = DataType; using YDataType = DataType; - using ComputeDataType = float; using XElementwiseOperation = ck_tile::element_wise::UnarySquare; // 1. Initialize the input data on the host @@ -64,7 +63,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // will cover some part of blockTile) using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp - using Shape = ck_tile::ElementWiseShape; + using Shape = ck_tile::ElementWiseShape; using Problem = ck_tile::ElementWisePipelineProblem{}); - static constexpr index_t kVectorM = 16 / sizeof(ComputeDataType); + static constexpr index_t kVectorM = + min(static_cast(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size()); static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{}); - static constexpr index_t kThreadPerWarpM = kWarpM / kVectorM; + static constexpr index_t kThreadPerWarpM = get_warp_size(); - static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kWarpM); + static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM); static constexpr index_t kBlockSize = ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 7013792335..9966c369be 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -53,7 +53,7 @@ class TestCkTileElementwise : public ::testing::Test using BlockTile_ = std::tuple_element_t<5, Tuple>; using WarpTile_ = std::tuple_element_t<6, Tuple>; using TestElementWiseShape = - ck_tile::ElementWiseShape; + ck_tile::ElementWiseShape; static constexpr int NumInputs = elementwise_op_traits::num_inputs; void RunTest(ck_tile::index_t total_m_elements) @@ -195,8 +195,7 @@ TYPED_TEST(TestCkTileElementwise, RunElementwise_1024) { this->RunTest(1024); } TYPED_TEST(TestCkTileElementwise, RunElementwise_513) { - EXPECT_THROW((this->RunTest(513)), - std::runtime_error); // Test with an input size that's not a multiple of kVectorM + this->RunTest(513); // Test with an input size that's not a multiple of kVectorM } TYPED_TEST(TestCkTileElementwise, RunElementwise_516)