From 1777ce3229e47d419e479bd345cfe7b7a8794fae Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Thu, 7 Aug 2025 15:45:27 +0300 Subject: [PATCH] [CK_TILE] Enable printing more structures in CK-Tile (#2443) * Add more printing to core cktile * Revert other changes in static encoding pattern * Refactor to using a free print() function * Remove loops and print just the containers * Print tuple with better formatting, fix sequence compilation * Add some tests for print utility * Add print utility header * Print for static_encoding_pattern * add buffer_view printing * Align vector_traits * Fix formatting * Lower-case enum strings Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> * Remove empty comment lines * Fix test with lower-case too * Reduce repeated code in print tests, move helper function closer to type definition, test X&Y * Add test_print_common.hpp * add print.hpp in core.hpp --------- Co-authored-by: Aviral Goel Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> [ROCm/composable_kernel commit: ffdee5e774cf73c3dc35869259ae8f460f969f1b] --- include/ck_tile/core.hpp | 1 + .../core/algorithm/coordinate_transform.hpp | 419 ++++++++---------- .../algorithm/static_encoding_pattern.hpp | 48 ++ include/ck_tile/core/arch/arch.hpp | 15 + include/ck_tile/core/container/array.hpp | 20 +- include/ck_tile/core/container/map.hpp | 35 +- include/ck_tile/core/container/sequence.hpp | 28 +- include/ck_tile/core/container/tuple.hpp | 21 +- .../core/numeric/integral_constant.hpp | 8 +- include/ck_tile/core/numeric/vector_type.hpp | 4 +- include/ck_tile/core/tensor/buffer_view.hpp | 109 +---- .../ck_tile/core/tensor/tensor_adaptor.hpp | 65 +-- .../ck_tile/core/tensor/tensor_descriptor.hpp | 42 +- .../ck_tile/core/tensor/tile_distribution.hpp | 41 +- .../tensor/tile_distribution_encoding.hpp | 204 ++++----- include/ck_tile/core/utility/print.hpp | 76 ++++ test/ck_tile/CMakeLists.txt | 3 +- test/ck_tile/utility/CMakeLists.txt | 4 + test/ck_tile/utility/print/CMakeLists.txt | 8 + test/ck_tile/utility/print/README.md | 70 +++ .../utility/print/test_print_array.cpp | 59 +++ .../utility/print/test_print_basic_types.cpp | 76 ++++ .../utility/print/test_print_buffer_view.cpp | 78 ++++ .../utility/print/test_print_common.hpp | 25 ++ .../print/test_print_coordinate_transform.cpp | 83 ++++ .../utility/print/test_print_sequence.cpp | 45 ++ .../test_print_static_encoding_pattern.cpp | 89 ++++ .../utility/print/test_print_tuple.cpp | 66 +++ 28 files changed, 1211 insertions(+), 531 deletions(-) create mode 100644 include/ck_tile/core/utility/print.hpp create mode 100644 test/ck_tile/utility/CMakeLists.txt create mode 100644 test/ck_tile/utility/print/CMakeLists.txt create mode 100644 test/ck_tile/utility/print/README.md create mode 100644 test/ck_tile/utility/print/test_print_array.cpp create mode 100644 test/ck_tile/utility/print/test_print_basic_types.cpp create mode 100644 test/ck_tile/utility/print/test_print_buffer_view.cpp create mode 100644 test/ck_tile/utility/print/test_print_common.hpp create mode 100644 test/ck_tile/utility/print/test_print_coordinate_transform.cpp create mode 100644 test/ck_tile/utility/print/test_print_sequence.cpp create mode 100644 test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp create mode 100644 test/ck_tile/utility/print/test_print_tuple.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 188cebaabc..c8945f03e9 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -74,6 +74,7 @@ #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" +#include "ck_tile/core/utility/print.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/static_counter.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index f7f9489f4c..7511413bba 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { @@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1> { return make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("pass_through{"); - - // - printf("up_lengths_:"); - print(up_lengths_); - - // - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const pass_through& pt) +{ + printf("pass_through{"); + + printf("up_lengths_: "); + print(pt.get_upper_lengths()); + + printf("}"); +} + template ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("left_pad_length_: "); - print(left_pad_length_); - printf(", "); - - // - printf("right_pad_length_: "); - print(right_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const pad& p) +{ + printf("pad{"); + printf("up_lengths_: "); + print(p.up_lengths_); + printf(", left_pad_length_: "); + print(p.left_pad_length_); + printf(", right_pad_length_: "); + print(p.right_pad_length_); + printf("}"); +} + template struct left_pad { @@ -330,24 +326,20 @@ struct left_pad // It's up to runtime to check the padding length should be multiple of vector length return make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("left_pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("left_pad_length_: "); - print(left_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const left_pad& lp) +{ + printf("left_pad{"); + printf("up_lengths_: "); + print(lp.up_lengths_); + printf(", left_pad_length_: "); + print(lp.left_pad_length_); + printf("}"); +} + template struct right_pad : public base_transform<1, 1> { @@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1> // It's up to runtime to check the padding length should be multiple of vector length return make_tuple(low_vector_lengths, low_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("right_pad{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("right_pad_length_: "); - print(right_pad_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void +print(const right_pad& rp) +{ + printf("right_pad{"); + printf("up_lengths_: "); + print(rp.up_lengths_); + printf(", right_pad_length_: "); + print(rp.right_pad_length_); + printf("}"); +} + // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] // UpLengths and Coefficients can be either of the followings: // 1) Tuple of index_t, which is known at run-time, or @@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()> return ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("embed{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("coefficients_: "); - print(coefficients_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const embed& e) +{ + printf("embed{"); + printf("up_lengths_: "); + print(e.up_lengths_); + printf(", coefficients_: "); + print(e.coefficients_); + printf("}"); +} + template struct lambda_merge_generate_MagicDivision_calculate_magic_divisor { @@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("merge_v2_magic_division{"); - - // - printf("low_lengths_ "); - print(low_lengths_); - printf(", "); - - // - printf("up_lengths_ "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division& m) +{ + printf("merge_v2_magic_division{"); + printf("low_lengths_: "); + print(m.low_lengths_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + // Implementation of "merge" transformation primitive that uses division and mod. It is supposed to // be used for low_lengths that are known at compile time and are power of 2, otherwise performance // will be very bad @@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("Merge_v3_direct_division_mod{"); - - // - printf("low_lengths_ "); - print(low_lengths_); - printf(", "); - - // - printf("low_lengths_scan_ "); - print(low_lengths_scan_); - printf(", "); - - // - printf("up_lengths_ "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod& m) +{ + printf("merge_v3_division_mod{"); + printf("low_lengths_: "); + print(m.low_lengths_); + printf(", low_lengths_scan_: "); + print(m.low_lengths_scan_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + template struct unmerge : public base_transform<1, UpLengths::size()> { @@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()> return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("unmerge{"); - - // - printf("up_lengths_"); - print(up_lengths_); - printf(", "); - - // - printf("up_lengths_scan_"); - print(up_lengths_scan_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const unmerge& u) +{ + printf("unmerge{"); + printf("up_lengths_: "); + print(u.up_lengths_); + printf(", up_lengths_scan_: "); + print(u.up_lengths_scan_); + printf("}"); +} + template struct freeze : public base_transform<1, 0> { @@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("freeze{"); - - // - printf("low_idx_: "); - print(low_idx_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const freeze& f) +{ + printf("freeze{"); + printf("low_idx_: "); + print(f.low_idx_); + printf("}"); +} + // insert a dangling upper dimension without lower dimension template struct insert : public base_transform<0, 1> @@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("insert{"); - - // - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const insert& i) +{ + printf("insert{"); + printf("up_lengths_: "); + print(i.up_lengths_); + printf("}"); +} + // replicate the original tensor and create a higher dimensional tensor template struct replicate : public base_transform<0, UpLengths::size()> @@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()> return ck_tile::is_known_at_compile_time::value; } - CK_TILE_HOST_DEVICE void print() const - { - printf("replicate{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - - printf("}"); - } - // UpLengths up_lengths_; }; +template +CK_TILE_HOST_DEVICE static void print(const replicate& r) +{ + printf("replicate{"); + printf("up_lengths_: "); + print(r.up_lengths_); + printf("}"); +} + template struct slice : public base_transform<1, 1> { @@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1> ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } +}; - CK_TILE_HOST_DEVICE void print() const - { - printf("slice{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("slice_begin_: "); - print(slice_begin_); - printf(", "); - - // - printf("slice_end_: "); - print(slice_end_); - - printf("}"); - } // namespace ck -}; // namespace ck +template +CK_TILE_HOST_DEVICE static void print(const slice& s) +{ + printf("slice{"); + printf("up_lengths_: "); + print(s.up_lengths_); + printf(", slice_begin_: "); + print(s.slice_begin_); + printf(", slice_end_: "); + print(s.slice_end_); + printf("}"); +} /* * \brief lower_idx = upper_idx % modulus. @@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1> { return ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("Modulus{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const modulo& m) +{ + printf("modulo{"); + printf("modulus_: "); + print(m.modulus_); + printf(", up_lengths_: "); + print(m.up_lengths_); + printf("}"); +} + // 2D XOR, NOTE: "xor" is a keyword template struct xor_t : public base_transform<2, 2> @@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2> return make_tuple(up_vector_lengths, up_vector_strides); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("xor_t{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const xor_t& x) +{ + printf("xor_t{"); + printf("up_lengths_: "); + print(x.up_lengths_); + printf("}"); +} + template struct offset : public base_transform<1, 1> { @@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1> return ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("offset{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - // - printf("offset_length_: "); - print(offset_length_); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const offset& o) +{ + printf("offset{"); + printf("up_lengths_: "); + print(o.up_lengths_); + printf(", offset_length_: "); + print(o.offset_length_); + printf("}"); +} + template struct indexing : public base_transform<1, 1> { @@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1> return ck_tile::is_known_at_compile_time::value && IndexingAdaptor::is_known_at_compile_time(); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("embed{"); - - // - printf("up_lengths_: "); - print(up_lengths_); - printf(", "); - - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const indexing& i) +{ + printf("indexing{"); + printf("up_lengths_: "); + print(i.up_lengths_); + printf(", iadaptor_: "); + print(i.iadaptor_); + printf("}"); +} + //******************************************************************************************************* template diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 8a3de3e5e0..1f6c389090 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -77,6 +77,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { @@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D +CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D&) +{ + using PatternType = TileDistributionEncodingPattern2D; + + printf("TileDistributionEncodingPattern2D: ", + BlockSize, + YPerTile, + XPerTile, + VecSize, + tile_distribution_pattern_to_string(DistributionPattern)); + printf("{: <%d, %d, %d>, : <%d, %d>}\n", + PatternType::Y0, + PatternType::Y1, + PatternType::Y2, + PatternType::X0, + PatternType::X1); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 96df9d70f7..ab42ec8617 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() #endif } +/// Helper function to convert address space enum to string +CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space) +{ + switch(addr_space) + { + case address_space_enum::generic: return "generic"; + case address_space_enum::global: return "global"; + case address_space_enum::lds: return "lds"; + case address_space_enum::sgpr: return "sgpr"; + case address_space_enum::constant: return "constant"; + case address_space_enum::vgpr: return "vgpr"; + default: return "unknown"; + } +} + } // namespace ck_tile diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 94aa40e278..352c645325 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -177,9 +177,27 @@ struct array CK_TILE_HOST_DEVICE constexpr array() {} CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; } CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; }; - CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } }; +template +CK_TILE_HOST_DEVICE static void print(const array& a) +{ + printf("array{size: %ld, data: [", static_cast(N)); + for(index_t i = 0; i < N; ++i) + { + if(i > 0) + printf(", "); + print(a[i]); + } + printf("]}"); +} + +template +CK_TILE_HOST_DEVICE static void print(const array&) +{ + printf("array{size: 0, data: []}"); +} + template struct vector_traits; diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index 87b180cafc..7697995c92 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -139,26 +139,21 @@ struct map // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; } - - CK_TILE_HOST_DEVICE void print() const - { - printf("map{size_: %d, ", size_); - // - printf("impl_: ["); - // - for(const auto& [k, d] : *this) - { - printf("{key: "); - print(k); - printf(", data: "); - print(d); - printf("}, "); - } - // - printf("]"); - // - printf("}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const map& m) +{ + printf("map{size_: %d, impl_: [", m.size_); + for(const auto& [k, d] : m) + { + printf("{key: "); + print(k); + printf(", data: "); + print(d); + printf("}, "); + } + printf("]}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 94309dd5dd..905b32dd15 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -9,13 +9,10 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { -template -struct static_for; - template struct sequence; @@ -196,15 +193,24 @@ struct sequence { return sequence{}; } - - CK_TILE_HOST_DEVICE static void print() - { - printf("sequence{size: %d, data: [", size()); - ((printf("%d ", Is)), ...); - printf("]}"); - } }; +template +CK_TILE_HOST_DEVICE static void print(const sequence&) +{ + printf("sequence<"); + if constexpr(sizeof...(Is) > 0) + { + bool first = true; + (([&first](index_t value) { + printf("%s%d", first ? "" : ", ", value); + first = false; + }(Is)), + ...); + } + printf(">"); +} + namespace impl { template struct __integer_sequence; diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 63d145d8b9..4c48b3d477 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -300,12 +300,29 @@ struct tuple : impl::tuple_base, T...> #undef TP_COM_ }; -template +template +CK_TILE_HOST_DEVICE void print(const tuple& t) +{ + printf("tuple<"); + if constexpr(sizeof...(T) > 0) + { + bool first = true; + static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) { + if(!first) + printf(", "); + print(t.get(i)); + first = false; + }); + } + printf(">"); +} + +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, void> { using scalar_type = __type_pack_element<0, T...>; static constexpr index_t vector_size = sizeof...(T); diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 33c24da8c5..2ba2fd10c6 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -19,14 +19,18 @@ struct constant CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } }; +template +CK_TILE_HOST_DEVICE static void print(const constant&) +{ + printf("%ld", static_cast(v)); +} + template struct integral_constant : constant { using value_type = T; using type = integral_constant; // using injected-class-name static constexpr T value = v; - // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } - // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } // }; template diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index b165275a8c..58bdb43b08 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector::type; // by default, any type will result in a vector_size=1 with scalar_type=T traits. // ... unless we have other vector_traits specialization -template +template struct vector_traits { using scalar_type = @@ -94,7 +94,7 @@ struct vector_traits // specialization for ext_vector_type() template -struct vector_traits +struct vector_traits { using scalar_type = std::conditional_t, int8_t, T>; static constexpr index_t vector_size = N; diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 4b39773939..ca314a6abe 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -210,28 +210,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: Global @@ -757,28 +735,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: LDS @@ -1138,28 +1094,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; // Address Space: Vgpr @@ -1313,28 +1247,6 @@ struct buffer_view(const_cast*>(p_data_))); - - // buffer_size_ - printf("buffer_size_: "); - print(buffer_size_); - printf(", "); - - // invalid_element_value_ - printf("invalid_element_value_: "); - print(invalid_element_value_); - - printf("}"); - } }; template +CK_TILE_HOST_DEVICE void print(const buffer_view& bv) +{ + printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ", + address_space_to_string(BufferAddressSpace), + static_cast(const_cast*>(bv.p_data_))); + print(bv.buffer_size_); + printf(", invalid_element_value_: "); + print(bv.invalid_element_value_); + printf("}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index e2a6ae6555..ec5538d79c 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -305,42 +305,45 @@ struct tensor_adaptor get_container_subset(vector_strides, top_dims)); } - CK_TILE_HOST_DEVICE void print() const - { - printf("tensor_adaptor{"); - - // - printf("transforms: "); - print(transforms_); - printf(", "); - - // - printf("LowerDimensionHiddenIds: "); - print(LowerDimensionHiddenIdss{}); - printf(", "); - - // - printf("UpperDimensionHiddenIds: "); - print(UpperDimensionHiddenIdss{}); - printf(", "); - - // - printf("BottomDimensionHiddenIds: "); - print(BottomDimensionHiddenIds{}); - printf(", "); - - // - printf("TopDimensionHiddenIds: "); - print(TopDimensionHiddenIds{}); - - printf("}"); - } - private: Transforms transforms_; ElementSize element_size_; }; +template +CK_TILE_HOST_DEVICE static void print(const tensor_adaptor& adaptor) +{ + printf("tensor_adaptor{\n"); + printf(" transforms: ["); + print(adaptor.get_transforms()); + printf("],\n"); + + printf(" LowerDimensionHiddenIds: ["); + print(LowerDimensionHiddenIdss{}); + printf("],\n"); + + printf(" UpperDimensionHiddenIds: ["); + print(UpperDimensionHiddenIdss{}); + printf("],\n"); + + printf(" BottomDimensionHiddenIds: ["); + print(BottomDimensionHiddenIds{}); + printf("],\n"); + + // + printf(" TopDimensionHiddenIds: ["); + print(TopDimensionHiddenIds{}); + printf("]\n}\n"); +} + // Transforms: Tuple // LowerDimensionOldTopIdss: Tuple, ...> // UpperDimensionNewTopIdss: Tuple, ...> diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 0c3e04f315..0e4787a2f1 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor(GuaranteedVectorStrides{})); } - CK_TILE_HOST_DEVICE void print() const - { - printf("tensor_descriptor{"); - - // tensor_adaptor - Base::print(); - printf(", "); - - // element_space_size_ - printf("element_space_size_: "); - print(element_space_size_); - - printf("}"); - } - // TODO make these private ElementSpaceSize element_space_size_; }; +template +CK_TILE_HOST_DEVICE static void print(const tensor_descriptor& descriptor) +{ + printf("tensor_descriptor{\n"); + // first print the tensor adaptor part of the descriptor using the base class print + print(static_cast(descriptor)); + printf("element_space_size_: %ld,\n", + static_cast(descriptor.get_element_space_size().value)); + printf("guaranteed_vector_lengths: "); + print(GuaranteedVectorLengths{}); + printf(",\nguaranteed_vector_strides: "); + print(GuaranteedVectorStrides{}); + printf("}\n}\n"); +} + template CK_TILE_HOST_DEVICE constexpr auto make_tensor_descriptor_from_adaptor(const Adaptor& adaptor, diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 11e6b35c39..bc02ec74d2 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -228,24 +228,6 @@ struct tile_distribution { return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution{"); - // - printf("tile_distribution_encoding: "); - print(DstrEncode{}); - printf(", "); - // - printf("ps_ys_to_xs_: "); - print(ps_ys_to_xs_); - printf(", "); - // - printf("ys_to_d_: "); - print(ys_to_d_); - // - printf("}"); - } }; namespace detail { @@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( } } // namespace detail + +// Free print function for tile_distribution +template +CK_TILE_HOST_DEVICE void print(const tile_distribution& distribution) +{ + printf("tile_distribution{"); + printf("tile_distribution_encoding: "); + print(StaticTileDistributionEncoding_{}); + printf(", "); + printf("ps_ys_to_xs_: "); + print(distribution.ps_ys_to_xs_); + printf(", "); + printf("ys_to_d_: "); + print(distribution.ys_to_d_); + printf("}\n"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index b380e7c9d8..90d1a2ccb2 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -428,109 +428,7 @@ struct tile_distribution_encoding { return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum()); } - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution_encoding::detail{"); - // - printf("ndim_rh_major_: "); - print(ndim_rh_major_); - printf(", "); - // - printf("ndim_span_major_: "); - print(ndim_span_major_); - printf(", "); - // - printf("ndims_rhs_minor_: "); - print(ndims_rhs_minor_); - printf(", "); - // - printf("ndim_rh_major_: "); - print(ndim_rh_major_); - printf(", "); - // - printf("max_ndim_rh_minor_: "); - print(max_ndim_rh_minor_); - printf(", "); - // - printf("rhs_lengthss_: "); - print(rhs_lengthss_); - printf(", "); - // - printf("ys_lengths_: "); - print(ys_lengths_); - printf(", "); - // - printf("rhs_major_minor_to_ys_: "); - print(rhs_major_minor_to_ys_); - printf(", "); - // - printf("ndims_span_minor_: "); - print(ndims_span_minor_); - printf(", "); - // - printf("max_ndim_span_minor_: "); - print(max_ndim_span_minor_); - printf(", "); - // - printf("ys_to_span_major_: "); - print(ys_to_span_major_); - printf(", "); - // - printf("ys_to_span_minor_: "); - print(ys_to_span_minor_); - printf(", "); - // - printf("distributed_spans_lengthss_: "); - print(distributed_spans_lengthss_); - printf(", "); - // - printf("ndims_distributed_spans_minor_: "); - print(ndims_distributed_spans_minor_); - printf(", "); - // - printf("ps_over_rs_derivative_: "); - print(ps_over_rs_derivative_); - // - printf("}"); - } }; - - CK_TILE_HOST_DEVICE void print() const - { - printf("tile_distribution_encoding{"); - // - printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY); - // - printf("rs_lengths_: "); - print(rs_lengths_); - printf(", "); - // - printf("hs_lengthss_: "); - print(hs_lengthss_); - printf(", "); - // - printf("ps_to_rhss_major_: "); - print(ps_to_rhss_major_); - printf(", "); - // - printf("ps_to_rhss_minor_: "); - print(ps_to_rhss_minor_); - printf(", "); - // - printf("ys_to_rhs_major_: "); - print(ys_to_rhs_major_); - printf(", "); - // - printf("ys_to_rhs_minor_: "); - print(ys_to_rhs_minor_); - printf(", "); - // - printf("detail: "); - print(detail{}); - // - printf("}"); - } }; template @@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence reduce } } // namespace detail + +// Free print function for tile_distribution_encoding::detail +template +CK_TILE_HOST_DEVICE void +print(const typename tile_distribution_encoding::detail& detail_obj) +{ + printf("tile_distribution_encoding::detail{"); + printf("ndim_rh_major_: "); + print(detail_obj.ndim_rh_major_); + printf(", "); + printf("ndim_span_major_: "); + print(detail_obj.ndim_span_major_); + printf(", "); + printf("ndims_rhs_minor_: "); + print(detail_obj.ndims_rhs_minor_); + printf(", "); + printf("ndim_rh_major_: "); + print(detail_obj.ndim_rh_major_); + printf(", "); + printf("max_ndim_rh_minor_: "); + print(detail_obj.max_ndim_rh_minor_); + printf(", "); + printf("rhs_lengthss_: "); + print(detail_obj.rhs_lengthss_); + printf(", "); + printf("ys_lengths_: "); + print(detail_obj.ys_lengths_); + printf(", "); + printf("rhs_major_minor_to_ys_: "); + print(detail_obj.rhs_major_minor_to_ys_); + printf(", "); + printf("ndims_span_minor_: "); + print(detail_obj.ndims_span_minor_); + printf(", "); + printf("max_ndim_span_minor_: "); + print(detail_obj.max_ndim_span_minor_); + printf(", "); + printf("ys_to_span_major_: "); + print(detail_obj.ys_to_span_major_); + printf(", "); + printf("ys_to_span_minor_: "); + print(detail_obj.ys_to_span_minor_); + printf(", "); + printf("distributed_spans_lengthss_: "); + print(detail_obj.distributed_spans_lengthss_); + printf(", "); + printf("ndims_distributed_spans_minor_: "); + print(detail_obj.ndims_distributed_spans_minor_); + printf(", "); + printf("ps_over_rs_derivative_: "); + print(detail_obj.ps_over_rs_derivative_); + printf("}"); +} + +// Free print function for tile_distribution_encoding +template +CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding& encoding) +{ + printf("tile_distribution_encoding{"); + + printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY); + printf("rs_lengths_: "); + print(encoding.rs_lengths_); + printf(", "); + printf("hs_lengthss_: "); + print(encoding.hs_lengthss_); + printf(", "); + printf("ps_to_rhss_major_: "); + print(encoding.ps_to_rhss_major_); + printf(", "); + printf("ps_to_rhss_minor_: "); + print(encoding.ps_to_rhss_minor_); + printf(", "); + printf("ys_to_rhs_major_: "); + print(encoding.ys_to_rhs_major_); + printf(", "); + printf("ys_to_rhs_minor_: "); + print(encoding.ys_to_rhs_minor_); + printf(", "); + printf("}"); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/utility/print.hpp b/include/ck_tile/core/utility/print.hpp new file mode 100644 index 0000000000..04635959af --- /dev/null +++ b/include/ck_tile/core/utility/print.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +/// Declare a ck_tile::print() interface that gets specialized in each header file for types that +/// can be printed. +template +CK_TILE_HOST_DEVICE void print(const T&) +{ + static_assert(sizeof(T) == 0, + "No print implementation available for this type. Please specialize " + "ck_tile::print for your type."); +} + +/// Specialization for int +template <> +CK_TILE_HOST_DEVICE void print(const int& value) +{ + printf("%d", value); +} + +/// Specialization for float +template <> +CK_TILE_HOST_DEVICE void print(const float& value) +{ + printf("%f", value); +} + +/// Specialization for double +template <> +CK_TILE_HOST_DEVICE void print(const double& value) +{ + printf("%f", value); +} + +/// Specialization for long +template <> +CK_TILE_HOST_DEVICE void print(const long& value) +{ + printf("%ld", value); +} + +/// Specialization for unsigned int +template <> +CK_TILE_HOST_DEVICE void print(const unsigned int& value) +{ + printf("%u", value); +} + +/// Specialization for char +template <> +CK_TILE_HOST_DEVICE void print(const char& value) +{ + printf("%c", value); +} + +/// Specialization for array +template +CK_TILE_HOST_DEVICE void print(const T (&value)[N]) +{ + printf("["); + for(size_t i = 0; i < N; ++i) + { + if(i > 0) + printf(", "); + print(value[i]); // Recursively call print for each element + } + printf("]"); +} + +} // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 9a1df56208..374e5b4990 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -21,4 +21,5 @@ add_subdirectory(add_rmsnorm2d_rdquant) # add_subdirectory(layernorm2d) # add_subdirectory(rmsnorm2d) add_subdirectory(gemm_block_scale) -add_subdirectory(reduce) \ No newline at end of file +add_subdirectory(utility) +add_subdirectory(reduce) diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt new file mode 100644 index 0000000000..c57cafca5a --- /dev/null +++ b/test/ck_tile/utility/CMakeLists.txt @@ -0,0 +1,4 @@ +message("-- Adding: test/ck_tile/utility/") + +# Add print tests +add_subdirectory(print) diff --git a/test/ck_tile/utility/print/CMakeLists.txt b/test/ck_tile/utility/print/CMakeLists.txt new file mode 100644 index 0000000000..5300dd20ca --- /dev/null +++ b/test/ck_tile/utility/print/CMakeLists.txt @@ -0,0 +1,8 @@ +# Print utility tests +add_gtest_executable(test_print_sequence test_print_sequence.cpp) +add_gtest_executable(test_print_array test_print_array.cpp) +add_gtest_executable(test_print_tuple test_print_tuple.cpp) +add_gtest_executable(test_print_coordinate_transform test_print_coordinate_transform.cpp) +add_gtest_executable(test_print_static_encoding_pattern test_print_static_encoding_pattern.cpp) +add_gtest_executable(test_print_buffer_view test_print_buffer_view.cpp) +add_gtest_executable(test_print_basic_types test_print_basic_types.cpp) diff --git a/test/ck_tile/utility/print/README.md b/test/ck_tile/utility/print/README.md new file mode 100644 index 0000000000..558c6faee4 --- /dev/null +++ b/test/ck_tile/utility/print/README.md @@ -0,0 +1,70 @@ +# Print Function Tests + +This directory contains unit tests for testing the print functionality of various data structures and coordinate transformations in the composable_kernel library. + +## Tests Included + +### test_print_sequence.cpp +Tests the print functionality for `sequence<...>` containers: +- Simple sequences with multiple elements +- Single element sequences +- Empty sequences +- Longer sequences + +### test_print_array.cpp +Tests the print functionality for `array` containers: +- Arrays with integer values +- Single element arrays +- Empty arrays (size 0) +- Arrays with floating point values + +### test_print_tuple.cpp +Tests the print functionality for `tuple<...>` containers: +- Simple tuples with numbers +- Single element tuples +- Empty tuples +- Mixed type tuples + +### test_print_coordinate_transform.cpp +Tests the print functionality for coordinate transformation structures: +- `pass_through` transform +- `embed` transform +- `merge` transform +- `unmerge` transform +- `freeze` transform + +## Testing Approach + +All tests use Google Test's `CaptureStdout()` functionality to capture the output from print functions and verify the formatting: + +```cpp +testing::internal::CaptureStdout(); +print(object); +std::string output = testing::internal::GetCapturedStdout(); +EXPECT_EQ(output, "expected_format"); +``` + +This approach enables testing of print function output without affecting the console during test execution. + +## Building and Running + +The tests are integrated into the CMake build system. To build and run the print tests: + +```bash +# Build the specific test +make test_print_sequence + +# Run the test +./test_print_sequence + +# Or run all print tests using CTest +ctest -R "test_print" +``` + +## Adding New Tests + +To add tests for new data structures: + +1. Create a new test file: `test_print_.cpp` +2. Follow the existing pattern using `CaptureStdout()` +3. Add the test executable to `CMakeLists.txt` diff --git a/test/ck_tile/utility/print/test_print_array.cpp b/test/ck_tile/utility/print/test_print_array.cpp new file mode 100644 index 0000000000..2fe9bc2a0c --- /dev/null +++ b/test/ck_tile/utility/print/test_print_array.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintArrayTest : public PrintTest +{ +}; + +TEST_F(PrintArrayTest, PrintIntArray) +{ + // Test printing array + array arr{10, 20, 30}; + + std::string output = CapturePrintOutput(arr); + + // The expected format should match the array print function implementation + EXPECT_EQ(output, "array{size: 3, data: [10, 20, 30]}"); +} + +TEST_F(PrintArrayTest, PrintSingleElementArray) +{ + // Test printing array + array arr{42}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "array{size: 1, data: [42]}"); +} + +TEST_F(PrintArrayTest, PrintEmptyArray) +{ + // Test printing array (empty array) + array arr{}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "array{size: 0, data: []}"); +} + +TEST_F(PrintArrayTest, PrintFloatArray) +{ + // Test printing array with float values + array arr{3.14f, 2.71f}; + + std::string output = CapturePrintOutput(arr); + + // Note: float printing format may vary, so we'll test for basic structure + EXPECT_TRUE(output.find("array{size: 2, data: [") == 0); + EXPECT_TRUE(output.find("3.14") != std::string::npos); + EXPECT_TRUE(output.find("2.71") != std::string::npos); + EXPECT_TRUE(output.find("]}") == output.length() - 2); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_basic_types.cpp b/test/ck_tile/utility/print/test_print_basic_types.cpp new file mode 100644 index 0000000000..7a26b6371a --- /dev/null +++ b/test/ck_tile/utility/print/test_print_basic_types.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintBasicTypesTest : public PrintTest +{ +}; + +TEST_F(PrintBasicTypesTest, PrintIntArray) +{ + int arr[4] = {1, 2, 3, 4}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[1, 2, 3, 4]"); +} + +TEST_F(PrintBasicTypesTest, PrintFloatArray) +{ + float arr[3] = {1.5f, 2.5f, 3.5f}; + + std::string output = CapturePrintOutput(arr); + + // Note: floating point formatting may vary, so we check for key elements + EXPECT_TRUE(output.find("[") == 0); + EXPECT_TRUE(output.find("1.5") != std::string::npos); + EXPECT_TRUE(output.find("2.5") != std::string::npos); + EXPECT_TRUE(output.find("3.5") != std::string::npos); + EXPECT_TRUE(output.back() == ']'); + EXPECT_TRUE(output.find(", ") != std::string::npos); +} + +TEST_F(PrintBasicTypesTest, PrintDoubleArray) +{ + double arr[2] = {10.123, 20.456}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_TRUE(output.find("[") == 0); + EXPECT_TRUE(output.find("10.123") != std::string::npos); + EXPECT_TRUE(output.find("20.456") != std::string::npos); + EXPECT_TRUE(output.back() == ']'); +} + +TEST_F(PrintBasicTypesTest, PrintUnsignedIntArray) +{ + unsigned int arr[3] = {100u, 200u, 300u}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[100, 200, 300]"); +} + +TEST_F(PrintBasicTypesTest, PrintCharArray) +{ + char arr[5] = {'a', 'b', 'c', 'd', 'e'}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[a, b, c, d, e]"); +} + +TEST_F(PrintBasicTypesTest, PrintSingleElementArray) +{ + int arr[1] = {42}; + + std::string output = CapturePrintOutput(arr); + + EXPECT_EQ(output, "[42]"); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_buffer_view.cpp b/test/ck_tile/utility/print/test_print_buffer_view.cpp new file mode 100644 index 0000000000..66668a2103 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_buffer_view.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/tensor/buffer_view.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintBufferViewTest : public PrintTest +{ +}; + +TEST_F(PrintBufferViewTest, PrintGenericBufferView) +{ + // Test printing generic address space buffer_view + float data[4] = {100.f, 200.f, 300.f, 400.f}; + auto bv = make_buffer_view(&data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: generic") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintGlobalBufferView) +{ + // Test printing global address space buffer_view + float data[4] = {100.f, 200.f, 300.f, 400.f}; + auto bv = make_buffer_view(&data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: global") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintLdsBufferView) +{ + // Test printing LDS address space buffer_view + float data[4] = {100.f, 200.f, 300.f, 400.f}; + auto bv = make_buffer_view(data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: lds") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +TEST_F(PrintBufferViewTest, PrintVgprBufferView) +{ + // Test printing VGPR address space buffer_view + float data[4] = {1.5f, 2.5f, 3.5f, 4.5f}; + auto bv = make_buffer_view(data, 4); + + std::string output = CapturePrintOutput(bv); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("buffer_view{AddressSpace: vgpr") != std::string::npos); + EXPECT_TRUE(output.find("p_data_:") != std::string::npos); + EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos); + EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos); + EXPECT_TRUE(output.find("}") != std::string::npos); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_common.hpp b/test/ck_tile/utility/print/test_print_common.hpp new file mode 100644 index 0000000000..3ba2270802 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_common.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core/utility/print.hpp" + +class PrintTest : public ::testing::Test +{ + protected: + void SetUp() override {} + void TearDown() override {} + // Helper function to capture and return the output of a print function + template + std::string CapturePrintOutput(const T& type) + { + using namespace ck_tile; + testing::internal::CaptureStdout(); + print(type); + return testing::internal::GetCapturedStdout(); + } +}; diff --git a/test/ck_tile/utility/print/test_print_coordinate_transform.cpp b/test/ck_tile/utility/print/test_print_coordinate_transform.cpp new file mode 100644 index 0000000000..639b113eb7 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_coordinate_transform.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintCoordinateTransformTest : public PrintTest +{ +}; + +TEST_F(PrintCoordinateTransformTest, PrintPassThrough) +{ + // Test printing pass_through transform + auto pt = make_pass_through_transform(number<32>{}); + + std::string output = CapturePrintOutput(pt); + + // Verify it contains the pass_through identifier and some structure + EXPECT_TRUE(output.find("pass_through{") == 0); + EXPECT_TRUE(output.find("up_lengths_") != std::string::npos); + EXPECT_TRUE(output.back() == '}'); +} + +TEST_F(PrintCoordinateTransformTest, PrintEmbed) +{ + // Test printing embed transform + auto embed_transform = make_embed_transform(make_tuple(number<4>{}, number<8>{}), + make_tuple(number<1>{}, number<4>{})); + + std::string output = CapturePrintOutput(embed_transform); + + // Verify it contains the embed identifier and key fields + EXPECT_TRUE(output.find("embed{") == 0); + EXPECT_TRUE(output.find("up_lengths_") != std::string::npos); + EXPECT_TRUE(output.find("coefficients_") != std::string::npos); + EXPECT_TRUE(output.back() == '}'); +} + +TEST_F(PrintCoordinateTransformTest, PrintMerge) +{ + // Test printing merge transform + auto merge_transform = make_merge_transform(make_tuple(number<4>{}, number<8>{})); + + std::string output = CapturePrintOutput(merge_transform); + + // Verify it contains merge identifier and key fields + EXPECT_TRUE(output.find("merge") == + 0); // Could be merge_v2_magic_division or merge_v3_division_mod + EXPECT_TRUE(output.find("low_lengths_") != std::string::npos || + output.find("up_lengths_") != std::string::npos); + EXPECT_TRUE(output.back() == '}'); +} + +TEST_F(PrintCoordinateTransformTest, PrintUnmerge) +{ + // Test printing unmerge transform + auto unmerge_transform = make_unmerge_transform(make_tuple(number<4>{}, number<8>{})); + + std::string output = CapturePrintOutput(unmerge_transform); + + // Verify it contains the unmerge identifier and key fields + EXPECT_TRUE(output.find("unmerge{") == 0); + EXPECT_TRUE(output.find("up_lengths_") != std::string::npos); + EXPECT_TRUE(output.back() == '}'); +} + +TEST_F(PrintCoordinateTransformTest, PrintFreeze) +{ + // Test printing freeze transform + auto freeze_transform = make_freeze_transform(number<5>{}); + + std::string output = CapturePrintOutput(freeze_transform); + + // Verify it contains the freeze identifier and key fields + EXPECT_TRUE(output.find("freeze{") == 0); + EXPECT_TRUE(output.find("low_idx_") != std::string::npos); + EXPECT_TRUE(output.back() == '}'); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_sequence.cpp b/test/ck_tile/utility/print/test_print_sequence.cpp new file mode 100644 index 0000000000..e73a9f7e33 --- /dev/null +++ b/test/ck_tile/utility/print/test_print_sequence.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/utility/print.hpp" +#include "ck_tile/core/container/sequence.hpp" + +namespace ck_tile { + +class PrintSequenceTest : public PrintTest +{ +}; + +TEST_F(PrintSequenceTest, PrintSimpleSequence) +{ + // Test printing sequence<1, 5, 8> + constexpr auto seq = sequence<1, 5, 8>{}; + + std::string output = CapturePrintOutput(seq); + + // Verify the output format + EXPECT_EQ(output, "sequence<1, 5, 8>"); +} + +TEST_F(PrintSequenceTest, PrintSingleElementSequence) +{ + // Test printing sequence<42> + constexpr auto seq = sequence<42>{}; + + std::string output = CapturePrintOutput(seq); + + EXPECT_EQ(output, "sequence<42>"); +} + +TEST_F(PrintSequenceTest, PrintEmptySequence) +{ + // Test printing sequence<> (empty sequence) + constexpr auto seq = sequence<>{}; + + std::string output = CapturePrintOutput(seq); + + EXPECT_EQ(output, "sequence<>"); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp new file mode 100644 index 0000000000..d1cb408b5c --- /dev/null +++ b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/algorithm/static_encoding_pattern.hpp" +#include "ck_tile/core/utility/print.hpp" + +#include + +namespace ck_tile { + +class PrintStaticEncodingPatternTest : public PrintTest +{ + protected: + void TestY0Y1Y2(const std::string& output, auto Y0, auto Y1, auto Y2) + { + std::stringstream expected; + expected << ": <" << Y0 << ", " << Y1 << ", " << Y2 << ">"; + EXPECT_TRUE(output.find(expected.str()) != std::string::npos); + } + void TestX0X1(const std::string& output, auto X0, auto X1) + { + std::stringstream expected; + expected << ": <" << X0 << ", " << X1 << ">"; + EXPECT_TRUE(output.find(expected.str()) != std::string::npos); + } +}; + +TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern) +{ + // Test printing thread raked pattern + using PatternType = + TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>; + PatternType pattern; + + std::string output = CapturePrintOutput(pattern); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos); + EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos); + EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos); + EXPECT_TRUE(output.find("VecSize:4") != std::string::npos); + EXPECT_TRUE(output.find("thread_raked") != std::string::npos); + TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2); + TestX0X1(output, PatternType::X0, PatternType::X1); +} + +TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern) +{ + // Test printing warp raked pattern + using PatternType = + TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>; + PatternType pattern; + + std::string output = CapturePrintOutput(pattern); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos); + EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos); + EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos); + EXPECT_TRUE(output.find("VecSize:8") != std::string::npos); + EXPECT_TRUE(output.find("warp_raked") != std::string::npos); + TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2); + TestX0X1(output, PatternType::X0, PatternType::X1); +} + +TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern) +{ + // Test printing block raked pattern + using PatternType = + TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>; + PatternType pattern; + + std::string output = CapturePrintOutput(pattern); + + // Verify the output contains expected information + EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos); + EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos); + EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos); + EXPECT_TRUE(output.find("VecSize:16") != std::string::npos); + EXPECT_TRUE(output.find("block_raked") != std::string::npos); + TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2); + TestX0X1(output, PatternType::X0, PatternType::X1); +} + +} // namespace ck_tile diff --git a/test/ck_tile/utility/print/test_print_tuple.cpp b/test/ck_tile/utility/print/test_print_tuple.cpp new file mode 100644 index 0000000000..79aaf1b3af --- /dev/null +++ b/test/ck_tile/utility/print/test_print_tuple.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_print_common.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/print.hpp" + +namespace ck_tile { + +class PrintTupleTest : public PrintTest +{ +}; + +TEST_F(PrintTupleTest, PrintSimpleTuple) +{ + // Test printing tuple with numbers + auto tup = make_tuple(number<1>{}, number<5>{}, number<8>{}); + + std::string output = CapturePrintOutput(tup); + + // Verify the output format matches tuple print implementation + EXPECT_TRUE(output.find("tuple<") == 0); + EXPECT_TRUE(output.find("1") != std::string::npos); + EXPECT_TRUE(output.find("5") != std::string::npos); + EXPECT_TRUE(output.find("8") != std::string::npos); + EXPECT_TRUE(output.back() == '>'); +} + +TEST_F(PrintTupleTest, PrintSingleElementTuple) +{ + // Test printing tuple with single element + auto tup = make_tuple(number<42>{}); + + std::string output = CapturePrintOutput(tup); + + EXPECT_TRUE(output.find("tuple<") == 0); + EXPECT_TRUE(output.find("42") != std::string::npos); + EXPECT_TRUE(output.back() == '>'); +} + +TEST_F(PrintTupleTest, PrintEmptyTuple) +{ + // Test printing empty tuple + auto tup = make_tuple(); + + std::string output = CapturePrintOutput(tup); + + EXPECT_EQ(output, "tuple<>"); +} + +TEST_F(PrintTupleTest, PrintMixedTypeTuple) +{ + // Test printing tuple with mixed types (numbers and constants) + auto tup = make_tuple(number<10>{}, constant<20>{}, number<30>{}); + + std::string output = CapturePrintOutput(tup); + + EXPECT_TRUE(output.find("tuple<") == 0); + EXPECT_TRUE(output.find("10") != std::string::npos); + EXPECT_TRUE(output.find("20") != std::string::npos); + EXPECT_TRUE(output.find("30") != std::string::npos); + EXPECT_TRUE(output.back() == '>'); +} + +} // namespace ck_tile