[CK_TILE] Enable printing more structures in CK-Tile (#2443)

* Add more printing to core cktile

* Revert other changes in static encoding pattern

* Refactor to use a free print() function

* Remove loops and print just the containers

* Print tuple with better formatting, fix sequence compilation

* Add some tests for print utility

* Add print utility header

* Print for static_encoding_pattern

* add buffer_view printing

* Align vector_traits

* Fix formatting

* Lower-case enum strings

Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>

* Remove empty comment lines

* Fix test with lower-case too

* Reduce repeated code in print tests, move helper function closer to type definition, test X&Y

* Add test_print_common.hpp

* add print.hpp in core.hpp

---------

Co-authored-by: Aviral Goel <aviral.goel@amd.com>
Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: ffdee5e774]
This commit is contained in:
Sami Remes
2025-08-07 15:45:27 +03:00
committed by GitHub
parent bdafbd7ca1
commit 1777ce3229
28 changed files with 1211 additions and 531 deletions

View File

@@ -74,6 +74,7 @@
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
// Member printer: writes "pass_through{up_lengths_:" (no space after the colon),
// then up_lengths_ via print(...), then "}".
// NOTE(review): inside this member, the unqualified print(...) call is looked up in
// class scope first and finds this zero-argument member — confirm it can reach the
// free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
/// Free printer for a pass_through transform.
/// Emits "pass_through{up_lengths_: <lengths>}" where <lengths> is produced by the
/// print overload selected for the value returned by get_upper_lengths().
template <typename LowLength>
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
{
    // Header and field label are emitted together; the payload follows.
    printf("pass_through{up_lengths_: ");
    print(pt.get_upper_lengths());
    printf("}");
}
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
// Member printer: writes "pad{up_lengths_: ..., left_pad_length_: ..., right_pad_length_: ...}".
// Each field value is delegated to print(...).
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
/// Free printer for a pad transform.
/// Emits "pad{up_lengths_: A, left_pad_length_: B, right_pad_length_: C}", with each
/// value rendered by the print overload matching the corresponding member.
template <typename LowLength,
          typename LeftPadLength,
          typename RightPadLength,
          bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
{
    const auto& up_lengths = p.up_lengths_;
    const auto& left_pad   = p.left_pad_length_;
    const auto& right_pad  = p.right_pad_length_;

    printf("pad{up_lengths_: ");
    print(up_lengths);
    printf(", left_pad_length_: ");
    print(left_pad);
    printf(", right_pad_length_: ");
    print(right_pad);
    printf("}");
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
@@ -330,24 +326,20 @@ struct left_pad
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
// Member printer: writes "left_pad{up_lengths_: ..., left_pad_length_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
/// Free printer for a left_pad transform.
/// Emits "left_pad{up_lengths_: A, left_pad_length_: B}".
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
{
    const auto& up_lengths = lp.up_lengths_;
    const auto& left_pad_len = lp.left_pad_length_;

    printf("left_pad{up_lengths_: ");
    print(up_lengths);
    printf(", left_pad_length_: ");
    print(left_pad_len);
    printf("}");
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
// Member printer: writes "right_pad{up_lengths_: ..., right_pad_length_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
/// Free printer for a right_pad transform.
/// Emits "right_pad{up_lengths_: A, right_pad_length_: B}".
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
{
    const auto& up_lengths    = rp.up_lengths_;
    const auto& right_pad_len = rp.right_pad_length_;

    printf("right_pad{up_lengths_: ");
    print(up_lengths);
    printf(", right_pad_length_: ");
    print(right_pad_len);
    printf("}");
}
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
// Member printer: writes "embed{up_lengths_: ..., coefficients_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
/// Free printer for an embed transform.
/// Emits "embed{up_lengths_: A, coefficients_: B}".
template <typename UpLengths, typename Coefficients>
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
{
    const auto& up_lengths   = e.up_lengths_;
    const auto& coefficients = e.coefficients_;

    printf("embed{up_lengths_: ");
    print(up_lengths);
    printf(", coefficients_: ");
    print(coefficients);
    printf("}");
}
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
// Member printer: writes "merge_v2_magic_division{low_lengths_ ..., up_lengths_ ...}".
// The field labels here carry no colon.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
/// Free printer for a merge_v2_magic_division transform.
/// Emits "merge_v2_magic_division{low_lengths_: A, up_lengths_: B}".
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
{
    const auto& low_lengths = m.low_lengths_;
    const auto& up_lengths  = m.up_lengths_;

    printf("merge_v2_magic_division{low_lengths_: ");
    print(low_lengths);
    printf(", up_lengths_: ");
    print(up_lengths);
    printf("}");
}
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
// Member printer: writes "Merge_v3_direct_division_mod{low_lengths_ ..., low_lengths_scan_ ...,
// up_lengths_ ...}". The printed tag differs from the struct name and the labels carry
// no colon.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
/// Free printer for a merge_v3_division_mod transform.
/// Emits "merge_v3_division_mod{low_lengths_: A, low_lengths_scan_: B, up_lengths_: C}".
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
{
    const auto& low_lengths      = m.low_lengths_;
    const auto& low_lengths_scan = m.low_lengths_scan_;
    const auto& up_lengths       = m.up_lengths_;

    printf("merge_v3_division_mod{low_lengths_: ");
    print(low_lengths);
    printf(", low_lengths_scan_: ");
    print(low_lengths_scan);
    printf(", up_lengths_: ");
    print(up_lengths);
    printf("}");
}
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
return make_tuple(up_vector_lengths, up_vector_strides);
}
// Member printer: writes "unmerge{up_lengths_<...>, up_lengths_scan_<...>}".
// The labels are emitted with no separator before the value.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
/// Free printer for an unmerge transform.
/// Emits "unmerge{up_lengths_: A, up_lengths_scan_: B}".
template <typename UpLengths, bool Use24BitIntegerCalculation>
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
{
    const auto& up_lengths      = u.up_lengths_;
    const auto& up_lengths_scan = u.up_lengths_scan_;

    printf("unmerge{up_lengths_: ");
    print(up_lengths);
    printf(", up_lengths_scan_: ");
    print(up_lengths_scan);
    printf("}");
}
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
// Member printer: writes "freeze{low_idx_: ...}".
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
/// Free printer for a freeze transform.
/// Emits "freeze{low_idx_: A}".
template <typename LowerIndex>
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
{
    const auto& low_idx = f.low_idx_;

    printf("freeze{low_idx_: ");
    print(low_idx);
    printf("}");
}
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
// Member printer: writes "insert{", then up_lengths_ (no field label), then "}".
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
/// Free printer for an insert transform.
/// Emits "insert{up_lengths_: A}".
template <typename UpperLength>
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
{
    const auto& up_lengths = i.up_lengths_;

    printf("insert{up_lengths_: ");
    print(up_lengths);
    printf("}");
}
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// Member printer: writes "replicate{up_lengths_: ...}".
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
/// Free printer for a replicate transform.
/// Emits "replicate{up_lengths_: A}".
template <typename UpLengths>
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
{
    const auto& up_lengths = r.up_lengths_;

    printf("replicate{up_lengths_: ");
    print(up_lengths);
    printf("}");
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
};
// Member printer: writes "slice{up_lengths_: ..., slice_begin_: ..., slice_end_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // NOTE(review): this brace closes print(), not a namespace — trailing label looks stale
}; // namespace ck
/// Free printer for a slice transform.
/// Emits "slice{up_lengths_: A, slice_begin_: B, slice_end_: C}".
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
{
    const auto& up_lengths  = s.up_lengths_;
    const auto& slice_begin = s.slice_begin_;
    const auto& slice_end   = s.slice_end_;

    printf("slice{up_lengths_: ");
    print(up_lengths);
    printf(", slice_begin_: ");
    print(slice_begin);
    printf(", slice_end_: ");
    print(slice_end);
    printf("}");
}
/*
* \brief lower_idx = upper_idx % modulus.
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// Member printer: writes "Modulus{up_lengths_: ...}". Only up_lengths_ is printed and
// the tag is "Modulus", not the struct name.
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
/// Free printer for a modulo transform.
/// Emits "modulo{modulus_: A, up_lengths_: B}".
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
{
    const auto& modulus    = m.modulus_;
    const auto& up_lengths = m.up_lengths_;

    printf("modulo{modulus_: ");
    print(modulus);
    printf(", up_lengths_: ");
    print(up_lengths);
    printf("}");
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
return make_tuple(up_vector_lengths, up_vector_strides);
}
// Member printer: writes "xor_t{up_lengths_: ..., }" — note the ", " emitted right
// before the closing brace.
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
/// Free printer for an xor_t transform.
/// Emits "xor_t{up_lengths_: A}".
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
{
    const auto& up_lengths = x.up_lengths_;

    printf("xor_t{up_lengths_: ");
    print(up_lengths);
    printf("}");
}
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
// Member printer: writes "offset{up_lengths_: ..., offset_length_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
/// Free printer for an offset transform.
/// Emits "offset{up_lengths_: A, offset_length_: B}".
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
{
    const auto& up_lengths    = o.up_lengths_;
    const auto& offset_length = o.offset_length_;

    printf("offset{up_lengths_: ");
    print(up_lengths);
    printf(", offset_length_: ");
    print(offset_length);
    printf("}");
}
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
// Member printer: writes "embed{up_lengths_: ..., }" — the tag says "embed" although
// this hunk belongs to struct indexing, and ", " is emitted before the closing brace.
// NOTE(review): the unqualified print(...) call is looked up in class scope first and
// finds this zero-argument member — confirm it can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
/// Free printer for an indexing transform.
/// Emits "indexing{up_lengths_: A, iadaptor_: B}"; B is rendered by whichever print
/// overload matches the adaptor member's type.
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
{
    const auto& up_lengths = i.up_lengths_;
    const auto& iadaptor   = i.iadaptor_;

    printf("indexing{up_lengths_: ");
    print(up_lengths);
    printf(", iadaptor_: ");
    print(iadaptor);
    printf("}");
}
//*******************************************************************************************************
template <typename LowLength>

View File

@@ -77,6 +77,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
};
// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
switch(pattern)
{
case tile_distribution_pattern::thread_raked: return "thread_raked";
case tile_distribution_pattern::warp_raked: return "warp_raked";
case tile_distribution_pattern::block_raked: return "block_raked";
default: return "unknown";
}
}
/// Free printer for TileDistributionEncodingPattern2D.
/// First emits the template arguments (BlockSize, YPerTile, XPerTile, VecSize and the
/// pattern name), then the Y0/Y1/Y2 and X0/X1 constants of the pattern type, followed
/// by a newline. NumWaveGroups is part of the type but is not printed.
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern,
          index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
                                                                       YPerTile,
                                                                       XPerTile,
                                                                       VecSize,
                                                                       DistributionPattern,
                                                                       NumWaveGroups>&)
{
    using PatternType = TileDistributionEncodingPattern2D<BlockSize,
                                                          YPerTile,
                                                          XPerTile,
                                                          VecSize,
                                                          DistributionPattern,
                                                          NumWaveGroups>;

    // Resolve the pattern name once, then emit the header line.
    const char* const pattern_name = tile_distribution_pattern_to_string(DistributionPattern);
    printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
           "VecSize:%d, %s>: ",
           BlockSize,
           YPerTile,
           XPerTile,
           VecSize,
           pattern_name);

    // Derived per-dimension factors exposed by the pattern type.
    const auto y0 = PatternType::Y0;
    const auto y1 = PatternType::Y1;
    const auto y2 = PatternType::Y2;
    const auto x0 = PatternType::X0;
    const auto x1 = PatternType::X1;
    printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n", y0, y1, y2, x0, x1);
}
} // namespace ck_tile

View File

@@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
#endif
}
/// Maps an address_space_enum value to its lower-case name.
/// @param addr_space  Enumerator to describe.
/// @return Pointer to a string literal naming the enumerator, or "unknown" when the
///         value matches none of the listed enumerators.
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
    if(addr_space == address_space_enum::generic)
        return "generic";
    if(addr_space == address_space_enum::global)
        return "global";
    if(addr_space == address_space_enum::lds)
        return "lds";
    if(addr_space == address_space_enum::sgpr)
        return "sgpr";
    if(addr_space == address_space_enum::constant)
        return "constant";
    if(addr_space == address_space_enum::vgpr)
        return "vgpr";
    return "unknown";
}
} // namespace ck_tile

View File

@@ -177,9 +177,27 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
/// Free printer for array<T, N>.
/// Emits "array{size: N, data: [e0, e1, ...]}" where each element is rendered by the
/// print overload for T and elements are separated by ", ".
template <typename T, index_t N>
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
{
    const long element_count = static_cast<long>(N);
    printf("array{size: %ld, data: [", element_count);

    index_t idx = 0;
    while(idx < N)
    {
        // Separator goes before every element except the first.
        if(idx != 0)
        {
            printf(", ");
        }
        print(a[idx]);
        ++idx;
    }

    printf("]}");
}
/// Free printer overload for the empty array<T, 0>.
/// Always writes the fixed text "array{size: 0, data: []}"; the argument is unused.
template <typename T>
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
{
printf("array{size: 0, data: []}");
}
template <typename, typename>
struct vector_traits;

View File

@@ -139,26 +139,21 @@ struct map
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
// Member printer: writes "map{size_: N, impl_: [{key: K, data: D}, ...]}".
// Every entry, including the last one, is followed by "}, ".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
/// Free printer for map<key, data, max_size>.
/// Emits "map{size_: N, impl_: [{key: K, data: D}, {key: K, data: D}]}".
/// Entries are separated by ", " with no separator after the last entry, matching the
/// array / tuple / sequence printers.
/// @param m  Map to print; it is iterated with a range-based for loop.
template <typename key, typename data, index_t max_size>
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
{
    printf("map{size_: %d, impl_: [", m.size_);

    bool first = true;
    for(const auto& [k, d] : m)
    {
        // Emit the separator before every entry except the first.
        if(!first)
        {
            printf(", ");
        }
        first = false;

        printf("{key: ");
        print(k);
        printf(", data: ");
        print(d);
        printf("}");
    }

    printf("]}");
}
} // namespace ck_tile

View File

@@ -9,13 +9,10 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -196,15 +193,24 @@ struct sequence
{
return sequence<f(Is)...>{};
}
// Static member printer: writes "sequence{size: N, data: [" followed by every value of
// the pack Is, each followed by a single space, then "]}".
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
/// Free printer for sequence<Is...>.
/// Emits "sequence<i0, i1, ...>"; an empty pack yields "sequence<>".
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
    printf("sequence<");
    if constexpr(sizeof...(Is) > 0)
    {
        // Count the values already written so that every value after the first is
        // preceded by ", ".
        index_t emitted = 0;
        auto emit_one   = [&emitted](index_t value) {
            if(emitted == 0)
            {
                printf("%s%d", "", value);
            }
            else
            {
                printf("%s%d", ", ", value);
            }
            ++emitted;
        };
        (emit_one(Is), ...);
    }
    printf(">");
}
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;

View File

@@ -300,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename, typename = void>
/// Free printer for tuple<T...>.
/// Emits "tuple<e0, e1, ...>" where each element is rendered by the print overload for
/// its type; an empty tuple yields "tuple<>".
template <typename... T>
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
{
    printf("tuple<");
    if constexpr(sizeof...(T) > 0)
    {
        // Number of elements written so far; used to decide whether a separator is due.
        index_t emitted = 0;
        static_for<0, sizeof...(T), 1>{}([&t, &emitted](auto i) {
            if(emitted != 0)
            {
                printf(", ");
            }
            print(t.get(i));
            ++emitted;
        });
    }
    printf(">");
}
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
struct vector_traits<tuple<T...>, void>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);

View File

@@ -19,14 +19,18 @@ struct constant
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
/// Free printer for constant<v>: writes the value v converted to long, using "%ld".
template <auto v>
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
{
    const long as_long = static_cast<long>(v);
    printf("%ld", as_long);
}
// Compile-time constant of type T with value v. Derives from constant<v> and exposes
// value_type, a self-referencing `type` alias, and the static member `value`.
// The conversion / call operators below are commented out in SOURCE.
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>

View File

@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
@@ -94,7 +94,7 @@ struct vector_traits
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;

View File

@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
// Member printer for the generic address-space specialization: writes
// "buffer_view{AddressSpace: generic, p_data_: <ptr>, buffer_size_: ..., invalid_element_value_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -757,28 +735,6 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
// Member printer for the global address-space specialization: writes
// "buffer_view{AddressSpace: Global, p_data_: <ptr>, buffer_size_: ..., invalid_element_value_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -1138,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
// Member printer for the LDS address-space specialization: writes
// "buffer_view{AddressSpace: Lds, p_data_: <ptr>, buffer_size_: ..., invalid_element_value_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
@@ -1313,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
// Member printer for the VGPR address-space specialization: writes
// "buffer_view{AddressSpace: Vgpr, p_data_: <ptr>, buffer_size_: ..., invalid_element_value_: ...}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
@@ -1360,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
p, buffer_size, invalid_element_value};
}
/// Single free printer shared by every buffer_view specialization.
/// Emits "buffer_view{AddressSpace: <name>, p_data_: <ptr>, buffer_size_: S,
/// invalid_element_value_: V}" where <name> comes from address_space_to_string and
/// S / V are rendered by the print overloads for their types.
template <address_space_enum BufferAddressSpace,
          typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
                                                 T,
                                                 BufferSizeType,
                                                 InvalidElementUseNumericalZeroValue,
                                                 Coherence>& bv)
{
    // %p expects void*, so strip cv-qualifiers from the element pointer first.
    void* const raw_data = static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_));
    const char* const space_name = address_space_to_string(BufferAddressSpace);

    printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ", space_name, raw_data);
    print(bv.buffer_size_);
    printf(", invalid_element_value_: ");
    print(bv.invalid_element_value_);
    printf("}");
}
} // namespace ck_tile

View File

@@ -305,42 +305,45 @@ struct tensor_adaptor
get_container_subset(vector_strides, top_dims));
}
// Member printer: writes "tensor_adaptor{transforms: ..., LowerDimensionHiddenIds: ...,
// UpperDimensionHiddenIds: ..., BottomDimensionHiddenIds: ..., TopDimensionHiddenIds: ...}".
// The hidden-id fields are printed from default-constructed instances of the types.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
/// Free printer for tensor_adaptor.
/// Writes a multi-line description: the transforms (via get_transforms()) followed by
/// the four hidden-id types, each printed from a default-constructed instance and
/// wrapped in "[...]". The output ends with "}\n".
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
                                                           LowerDimensionHiddenIdss,
                                                           UpperDimensionHiddenIdss,
                                                           BottomDimensionHiddenIds,
                                                           TopDimensionHiddenIds>& adaptor)
{
    // Each printf closes the previous field and opens the next one.
    printf("tensor_adaptor{\n transforms: [");
    print(adaptor.get_transforms());

    printf("],\n LowerDimensionHiddenIds: [");
    print(LowerDimensionHiddenIdss{});

    printf("],\n UpperDimensionHiddenIds: [");
    print(UpperDimensionHiddenIdss{});

    printf("],\n BottomDimensionHiddenIds: [");
    print(BottomDimensionHiddenIds{});

    printf("],\n TopDimensionHiddenIds: [");
    print(TopDimensionHiddenIds{});

    printf("]\n}\n");
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>

View File

@@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
// Member printer: writes "tensor_descriptor{", the base adaptor via Base::print(),
// then ", element_space_size_: ..." and "}".
// NOTE(review): the unqualified print(element_space_size_) call is looked up in class
// scope first and finds a zero-argument member print — confirm it can reach the free
// print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
/// Free printer for tensor_descriptor.
///
/// Output layout:
///   tensor_descriptor{
///   <tensor_adaptor part, printed through the Base overload>
///   element_space_size_: N,
///   guaranteed_vector_lengths: <...>,
///   guaranteed_vector_strides: <...>
///   }
///
/// The element space size is read through `.value`, so this overload requires
/// get_element_space_size() to return a type exposing a `value` member.
///
/// @param descriptor  Descriptor to print.
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths,
          typename GuaranteedVectorStrides>
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
                                                              LowerDimensionHiddenIdss,
                                                              UpperDimensionHiddenIdss,
                                                              TopDimensionHiddenIds,
                                                              ElementSpaceSize,
                                                              GuaranteedVectorLengths,
                                                              GuaranteedVectorStrides>& descriptor)
{
    // Name the descriptor type explicitly: decltype(descriptor) is a reference type and
    // cannot be used on the left-hand side of `::`.
    using Descriptor = tensor_descriptor<Transforms,
                                         LowerDimensionHiddenIdss,
                                         UpperDimensionHiddenIdss,
                                         TopDimensionHiddenIds,
                                         ElementSpaceSize,
                                         GuaranteedVectorLengths,
                                         GuaranteedVectorStrides>;

    printf("tensor_descriptor{\n");
    // first print the tensor adaptor part of the descriptor using the base class print
    print(static_cast<const typename Descriptor::Base&>(descriptor));
    printf("element_space_size_: %ld,\n",
           static_cast<long>(descriptor.get_element_space_size().value));
    printf("guaranteed_vector_lengths: ");
    print(GuaranteedVectorLengths{});
    printf(",\nguaranteed_vector_strides: ");
    print(GuaranteedVectorStrides{});
    // A single closing brace matches the single "tensor_descriptor{" opened above.
    printf("\n}\n");
}
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,

View File

@@ -228,24 +228,6 @@ struct tile_distribution
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
// Member printer: writes "tile_distribution{tile_distribution_encoding: ...,
// ps_ys_to_xs_: ..., ys_to_d_: ...}". The encoding is printed from a
// default-constructed DstrEncode.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
@@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
}
} // namespace detail
/// Free printer for tile_distribution.
/// Emits "tile_distribution{tile_distribution_encoding: E, ps_ys_to_xs_: A, ys_to_d_: D}"
/// followed by a newline. E is printed from a default-constructed encoding type; A and
/// D are the distribution's ps_ys_to_xs_ and ys_to_d_ members.
template <typename PsYs2XsAdaptor_,
          typename Ys2DDescriptor_,
          typename StaticTileDistributionEncoding_,
          typename TileDistributionDetail_>
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
                                                       Ys2DDescriptor_,
                                                       StaticTileDistributionEncoding_,
                                                       TileDistributionDetail_>& distribution)
{
    printf("tile_distribution{tile_distribution_encoding: ");
    print(StaticTileDistributionEncoding_{});

    printf(", ps_ys_to_xs_: ");
    print(distribution.ps_ys_to_xs_);

    printf(", ys_to_d_: ");
    print(distribution.ys_to_d_);

    printf("}\n");
}
} // namespace ck_tile

View File

@@ -428,109 +428,7 @@ struct tile_distribution_encoding
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
// Member printer for tile_distribution_encoding::detail: writes
// "tile_distribution_encoding::detail{...}" with one "name: value" pair per field,
// separated by ", ". Note that ndim_rh_major_ is printed twice.
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
// Member printer for tile_distribution_encoding: writes "tile_distribution_encoding{",
// the NDimX / NDimP / NDimY counts, the six mapping members, and finally a
// default-constructed `detail` object, then "}".
// NOTE(review): the unqualified print(...) calls are looked up in class scope first and
// find this zero-argument member — confirm they can reach the free print overloads.
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
template <typename encoding, typename shuffle>
@@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
}
} // namespace detail
/// Free print function for tile_distribution_encoding::detail.
///
/// Emits "tile_distribution_encoding::detail{name: value, ...}" with every field
/// printed exactly once.
///
/// NOTE: the parameter type is `typename tile_distribution_encoding<...>::detail`, which
/// is a non-deduced context. The six template arguments therefore cannot be deduced
/// from a call such as `print(detail_obj)`; callers must supply them explicitly, e.g.
/// `print<Rs, Hs, PsMajor, PsMinor, YsMajor, YsMinor>(detail_obj)`.
///
/// @param detail_obj  The detail object whose members are printed.
template <typename RsLengths_,
          typename HsLengthss_,
          typename Ps2RHssMajor_,
          typename Ps2RHssMinor_,
          typename Ys2RHsMajor_,
          typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void
print(const typename tile_distribution_encoding<RsLengths_,
                                                HsLengthss_,
                                                Ps2RHssMajor_,
                                                Ps2RHssMinor_,
                                                Ys2RHsMajor_,
                                                Ys2RHsMinor_>::detail& detail_obj)
{
    printf("tile_distribution_encoding::detail{");
    printf("ndim_rh_major_: ");
    print(detail_obj.ndim_rh_major_);
    printf(", ");
    printf("ndim_span_major_: ");
    print(detail_obj.ndim_span_major_);
    printf(", ");
    printf("ndims_rhs_minor_: ");
    print(detail_obj.ndims_rhs_minor_);
    printf(", ");
    // ndim_rh_major_ was previously emitted a second time here; it is printed once above.
    printf("max_ndim_rh_minor_: ");
    print(detail_obj.max_ndim_rh_minor_);
    printf(", ");
    printf("rhs_lengthss_: ");
    print(detail_obj.rhs_lengthss_);
    printf(", ");
    printf("ys_lengths_: ");
    print(detail_obj.ys_lengths_);
    printf(", ");
    printf("rhs_major_minor_to_ys_: ");
    print(detail_obj.rhs_major_minor_to_ys_);
    printf(", ");
    printf("ndims_span_minor_: ");
    print(detail_obj.ndims_span_minor_);
    printf(", ");
    printf("max_ndim_span_minor_: ");
    print(detail_obj.max_ndim_span_minor_);
    printf(", ");
    printf("ys_to_span_major_: ");
    print(detail_obj.ys_to_span_major_);
    printf(", ");
    printf("ys_to_span_minor_: ");
    print(detail_obj.ys_to_span_minor_);
    printf(", ");
    printf("distributed_spans_lengthss_: ");
    print(detail_obj.distributed_spans_lengthss_);
    printf(", ");
    printf("ndims_distributed_spans_minor_: ");
    print(detail_obj.ndims_distributed_spans_minor_);
    printf(", ");
    printf("ps_over_rs_derivative_: ");
    print(detail_obj.ps_over_rs_derivative_);
    printf("}");
}
/// Free print function for tile_distribution_encoding.
///
/// Emits "tile_distribution_encoding{NDimX: x, NDimP: p, NDimY: y, rs_lengths_: ...,
/// hs_lengthss_: ..., ps_to_rhss_major_: ..., ps_to_rhss_minor_: ...,
/// ys_to_rhs_major_: ..., ys_to_rhs_minor_: ...}". Fields are separated by ", " and no
/// separator is written after the last field.
///
/// @param encoding  Encoding whose dimension counts and mapping members are printed.
template <typename RsLengths_,
          typename HsLengthss_,
          typename Ps2RHssMajor_,
          typename Ps2RHssMinor_,
          typename Ys2RHsMajor_,
          typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
                                                                HsLengthss_,
                                                                Ps2RHssMajor_,
                                                                Ps2RHssMinor_,
                                                                Ys2RHsMajor_,
                                                                Ys2RHsMinor_>& encoding)
{
    printf("tile_distribution_encoding{");
    printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
    printf("rs_lengths_: ");
    print(encoding.rs_lengths_);
    printf(", ");
    printf("hs_lengthss_: ");
    print(encoding.hs_lengthss_);
    printf(", ");
    printf("ps_to_rhss_major_: ");
    print(encoding.ps_to_rhss_major_);
    printf(", ");
    printf("ps_to_rhss_minor_: ");
    print(encoding.ps_to_rhss_minor_);
    printf(", ");
    printf("ys_to_rhs_major_: ");
    print(encoding.ys_to_rhs_major_);
    printf(", ");
    printf("ys_to_rhs_minor_: ");
    print(encoding.ys_to_rhs_minor_);
    // ys_to_rhs_minor_ is the last field, so close the brace without a trailing ", ".
    printf("}");
}
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {

/// Primary ck_tile::print() template: the fallback chosen for any type that has no dedicated
/// printer. Instantiating it is always a compile-time error; printable types provide their own
/// specialization or overload (other headers add theirs for the types they define).
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
    // The condition is never true, but because it depends on T the assertion only fires when
    // this fallback is actually instantiated.
    static_assert(sizeof(T) == 0,
                  "No print implementation available for this type. Please specialize "
                  "ck_tile::print for your type.");
}

/// Specialization for char: prints the character itself ("%c").
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
    printf("%c", value);
}

/// Specialization for int: prints in decimal ("%d").
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
    printf("%d", value);
}

/// Specialization for unsigned int: prints in decimal ("%u").
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
    printf("%u", value);
}

/// Specialization for long: prints in decimal ("%ld").
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
    printf("%ld", value);
}

/// Specialization for float: prints with printf's "%f" conversion.
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
    printf("%f", value);
}

/// Specialization for double: prints with printf's "%f" conversion.
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
    printf("%f", value);
}

/// Overload (not a specialization) for built-in arrays: prints "[e0, e1, ...]", forwarding every
/// element to print() so nested arrays and any printable element type are handled.
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
    printf("[");
    size_t idx = 0;
    while(idx != N)
    {
        // Separator goes before every element except the first.
        if(idx != 0)
        {
            printf(", ");
        }
        print(value[idx]);
        ++idx;
    }
    printf("]");
}

} // namespace ck_tile

View File

@@ -21,4 +21,5 @@ add_subdirectory(add_rmsnorm2d_rdquant)
# add_subdirectory(layernorm2d)
# add_subdirectory(rmsnorm2d)
add_subdirectory(gemm_block_scale)
add_subdirectory(reduce)
add_subdirectory(utility)
add_subdirectory(reduce)

View File

@@ -0,0 +1,4 @@
# Announce at configure time that the ck_tile utility tests are being added.
message("-- Adding: test/ck_tile/utility/")
# Pull in the print() tests defined in the print/ subdirectory.
add_subdirectory(print)

View File

@@ -0,0 +1,8 @@
# Print utility tests: every target below is registered through add_gtest_executable
# and is built from the single source file that shares the target's name.
foreach(print_test_target
        test_print_sequence
        test_print_array
        test_print_tuple
        test_print_coordinate_transform
        test_print_static_encoding_pattern
        test_print_buffer_view
        test_print_basic_types)
    add_gtest_executable(${print_test_target} ${print_test_target}.cpp)
endforeach()

View File

@@ -0,0 +1,70 @@
# Print Function Tests
This directory contains unit tests for the `ck_tile::print()` functions of various data structures and coordinate transformations in the composable_kernel library.
## Tests Included
### test_print_sequence.cpp
Tests the print functionality for `sequence<...>` containers:
- Simple sequences with multiple elements
- Single element sequences
- Empty sequences
- Longer sequences
### test_print_array.cpp
Tests the print functionality for `array<T, N>` containers:
- Arrays with integer values
- Single element arrays
- Empty arrays (size 0)
- Arrays with floating point values
### test_print_tuple.cpp
Tests the print functionality for `tuple<...>` containers:
- Simple tuples with numbers
- Single element tuples
- Empty tuples
- Mixed type tuples
### test_print_coordinate_transform.cpp
Tests the print functionality for coordinate transformation structures:
- `pass_through` transform
- `embed` transform
- `merge` transform
- `unmerge` transform
- `freeze` transform
## Testing Approach
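All tests capture what `print()` writes to stdout with Google Test's `testing::internal::CaptureStdout()` / `GetCapturedStdout()` pair (wrapped by the `CapturePrintOutput()` helper in `test_print_common.hpp`) and then verify the formatting: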
```cpp
testing::internal::CaptureStdout();
print(object);
std::string output = testing::internal::GetCapturedStdout();
EXPECT_EQ(output, "expected_format");
```
This approach enables testing of print function output without affecting the console during test execution.
## Building and Running
The tests are integrated into the CMake build system. To build and run the print tests:
```bash
# Build the specific test
make test_print_sequence
# Run the test
./test_print_sequence
# Or run all print tests using CTest
ctest -R "test_print"
```
## Adding New Tests
To add tests for new data structures:
1. Create a new test file: `test_print_<structure_name>.cpp`
2. Follow the existing pattern: derive the fixture from `PrintTest` and capture the output with `CapturePrintOutput()`, which uses `CaptureStdout()`
3. Add the test executable to `CMakeLists.txt`

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
class PrintArrayTest : public PrintTest
{
};
TEST_F(PrintArrayTest, PrintIntArray)
{
// Test printing array<int, 3>
array<int, 3> arr{10, 20, 30};
std::string output = CapturePrintOutput(arr);
// The expected format should match the array print function implementation
EXPECT_EQ(output, "array{size: 3, data: [10, 20, 30]}");
}
TEST_F(PrintArrayTest, PrintSingleElementArray)
{
// Test printing array<int, 1>
array<int, 1> arr{42};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "array{size: 1, data: [42]}");
}
TEST_F(PrintArrayTest, PrintEmptyArray)
{
// Test printing array<int, 0> (empty array)
array<int, 0> arr{};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "array{size: 0, data: []}");
}
TEST_F(PrintArrayTest, PrintFloatArray)
{
// Test printing array with float values
array<float, 2> arr{3.14f, 2.71f};
std::string output = CapturePrintOutput(arr);
// Note: float printing format may vary, so we'll test for basic structure
EXPECT_TRUE(output.find("array{size: 2, data: [") == 0);
EXPECT_TRUE(output.find("3.14") != std::string::npos);
EXPECT_TRUE(output.find("2.71") != std::string::npos);
EXPECT_TRUE(output.find("]}") == output.length() - 2);
}
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
class PrintBasicTypesTest : public PrintTest
{
};
TEST_F(PrintBasicTypesTest, PrintIntArray)
{
int arr[4] = {1, 2, 3, 4};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "[1, 2, 3, 4]");
}
TEST_F(PrintBasicTypesTest, PrintFloatArray)
{
float arr[3] = {1.5f, 2.5f, 3.5f};
std::string output = CapturePrintOutput(arr);
// Note: floating point formatting may vary, so we check for key elements
EXPECT_TRUE(output.find("[") == 0);
EXPECT_TRUE(output.find("1.5") != std::string::npos);
EXPECT_TRUE(output.find("2.5") != std::string::npos);
EXPECT_TRUE(output.find("3.5") != std::string::npos);
EXPECT_TRUE(output.back() == ']');
EXPECT_TRUE(output.find(", ") != std::string::npos);
}
TEST_F(PrintBasicTypesTest, PrintDoubleArray)
{
double arr[2] = {10.123, 20.456};
std::string output = CapturePrintOutput(arr);
EXPECT_TRUE(output.find("[") == 0);
EXPECT_TRUE(output.find("10.123") != std::string::npos);
EXPECT_TRUE(output.find("20.456") != std::string::npos);
EXPECT_TRUE(output.back() == ']');
}
TEST_F(PrintBasicTypesTest, PrintUnsignedIntArray)
{
unsigned int arr[3] = {100u, 200u, 300u};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "[100, 200, 300]");
}
TEST_F(PrintBasicTypesTest, PrintCharArray)
{
char arr[5] = {'a', 'b', 'c', 'd', 'e'};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "[a, b, c, d, e]");
}
TEST_F(PrintBasicTypesTest, PrintSingleElementArray)
{
int arr[1] = {42};
std::string output = CapturePrintOutput(arr);
EXPECT_EQ(output, "[42]");
}
} // namespace ck_tile

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
class PrintBufferViewTest : public PrintTest
{
};
TEST_F(PrintBufferViewTest, PrintGenericBufferView)
{
// Test printing generic address space buffer_view
float data[4] = {100.f, 200.f, 300.f, 400.f};
auto bv = make_buffer_view<address_space_enum::generic>(&data, 4);
std::string output = CapturePrintOutput(bv);
// Verify the output contains expected information
EXPECT_TRUE(output.find("buffer_view{AddressSpace: generic") != std::string::npos);
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
EXPECT_TRUE(output.find("}") != std::string::npos);
}
TEST_F(PrintBufferViewTest, PrintGlobalBufferView)
{
// Test printing global address space buffer_view
float data[4] = {100.f, 200.f, 300.f, 400.f};
auto bv = make_buffer_view<address_space_enum::global>(&data, 4);
std::string output = CapturePrintOutput(bv);
// Verify the output contains expected information
EXPECT_TRUE(output.find("buffer_view{AddressSpace: global") != std::string::npos);
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
EXPECT_TRUE(output.find("}") != std::string::npos);
}
TEST_F(PrintBufferViewTest, PrintLdsBufferView)
{
// Test printing LDS address space buffer_view
float data[4] = {100.f, 200.f, 300.f, 400.f};
auto bv = make_buffer_view<address_space_enum::lds>(data, 4);
std::string output = CapturePrintOutput(bv);
// Verify the output contains expected information
EXPECT_TRUE(output.find("buffer_view{AddressSpace: lds") != std::string::npos);
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
EXPECT_TRUE(output.find("}") != std::string::npos);
}
TEST_F(PrintBufferViewTest, PrintVgprBufferView)
{
// Test printing VGPR address space buffer_view
float data[4] = {1.5f, 2.5f, 3.5f, 4.5f};
auto bv = make_buffer_view<address_space_enum::vgpr>(data, 4);
std::string output = CapturePrintOutput(bv);
// Verify the output contains expected information
EXPECT_TRUE(output.find("buffer_view{AddressSpace: vgpr") != std::string::npos);
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
EXPECT_TRUE(output.find("}") != std::string::npos);
}
} // namespace ck_tile

View File

@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <gtest/gtest.h>
#include <gtest/gtest-spi.h>
#include "ck_tile/core/utility/print.hpp"
/// Shared base fixture for the ck_tile::print() tests.
class PrintTest : public ::testing::Test
{
    protected:
    void SetUp() override {}
    void TearDown() override {}

    /// Calls print(type) while Google Test captures stdout and returns the captured text.
    ///
    /// @param type Object to print; any type with a ck_tile::print overload or specialization.
    /// @return Everything written to stdout during the print call.
    template <typename T>
    std::string CapturePrintOutput(const T& type)
    {
        // The call to print stays unqualified on purpose: test sources include this header
        // before the headers that declare their print overloads, so those overloads are only
        // reachable through argument-dependent lookup when this template is instantiated.
        using namespace ck_tile;

        ::testing::internal::CaptureStdout();
        print(type);
        std::string captured = ::testing::internal::GetCapturedStdout();
        return captured;
    }
};

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
class PrintCoordinateTransformTest : public PrintTest
{
};
TEST_F(PrintCoordinateTransformTest, PrintPassThrough)
{
// Test printing pass_through transform
auto pt = make_pass_through_transform(number<32>{});
std::string output = CapturePrintOutput(pt);
// Verify it contains the pass_through identifier and some structure
EXPECT_TRUE(output.find("pass_through{") == 0);
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
EXPECT_TRUE(output.back() == '}');
}
TEST_F(PrintCoordinateTransformTest, PrintEmbed)
{
// Test printing embed transform
auto embed_transform = make_embed_transform(make_tuple(number<4>{}, number<8>{}),
make_tuple(number<1>{}, number<4>{}));
std::string output = CapturePrintOutput(embed_transform);
// Verify it contains the embed identifier and key fields
EXPECT_TRUE(output.find("embed{") == 0);
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
EXPECT_TRUE(output.find("coefficients_") != std::string::npos);
EXPECT_TRUE(output.back() == '}');
}
TEST_F(PrintCoordinateTransformTest, PrintMerge)
{
// Test printing merge transform
auto merge_transform = make_merge_transform(make_tuple(number<4>{}, number<8>{}));
std::string output = CapturePrintOutput(merge_transform);
// Verify it contains merge identifier and key fields
EXPECT_TRUE(output.find("merge") ==
0); // Could be merge_v2_magic_division or merge_v3_division_mod
EXPECT_TRUE(output.find("low_lengths_") != std::string::npos ||
output.find("up_lengths_") != std::string::npos);
EXPECT_TRUE(output.back() == '}');
}
TEST_F(PrintCoordinateTransformTest, PrintUnmerge)
{
// Test printing unmerge transform
auto unmerge_transform = make_unmerge_transform(make_tuple(number<4>{}, number<8>{}));
std::string output = CapturePrintOutput(unmerge_transform);
// Verify it contains the unmerge identifier and key fields
EXPECT_TRUE(output.find("unmerge{") == 0);
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
EXPECT_TRUE(output.back() == '}');
}
TEST_F(PrintCoordinateTransformTest, PrintFreeze)
{
// Test printing freeze transform
auto freeze_transform = make_freeze_transform(number<5>{});
std::string output = CapturePrintOutput(freeze_transform);
// Verify it contains the freeze identifier and key fields
EXPECT_TRUE(output.find("freeze{") == 0);
EXPECT_TRUE(output.find("low_idx_") != std::string::npos);
EXPECT_TRUE(output.back() == '}');
}
} // namespace ck_tile

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/container/sequence.hpp"
namespace ck_tile {

/// Fixture for the sequence<...> print tests; all behaviour comes from PrintTest.
class PrintSequenceTest : public PrintTest
{
};

TEST_F(PrintSequenceTest, PrintSimpleSequence)
{
    // A three-element sequence is printed as a comma-separated list in angle brackets.
    const std::string printed = CapturePrintOutput(sequence<1, 5, 8>{});

    EXPECT_EQ(printed, "sequence<1, 5, 8>");
}

TEST_F(PrintSequenceTest, PrintSingleElementSequence)
{
    // A single element is printed without any separator.
    const std::string printed = CapturePrintOutput(sequence<42>{});

    EXPECT_EQ(printed, "sequence<42>");
}

TEST_F(PrintSequenceTest, PrintEmptySequence)
{
    // An empty sequence still prints its angle brackets.
    const std::string printed = CapturePrintOutput(sequence<>{});

    EXPECT_EQ(printed, "sequence<>");
}

} // namespace ck_tile

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/utility/print.hpp"
#include <sstream>
namespace ck_tile {
class PrintStaticEncodingPatternTest : public PrintTest
{
protected:
void TestY0Y1Y2(const std::string& output, auto Y0, auto Y1, auto Y2)
{
std::stringstream expected;
expected << "<Y0, Y1, Y2>: <" << Y0 << ", " << Y1 << ", " << Y2 << ">";
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
}
void TestX0X1(const std::string& output, auto X0, auto X1)
{
std::stringstream expected;
expected << "<X0, X1>: <" << X0 << ", " << X1 << ">";
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
}
};
TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern)
{
// Test printing thread raked pattern
using PatternType =
TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>;
PatternType pattern;
std::string output = CapturePrintOutput(pattern);
// Verify the output contains expected information
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos);
EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos);
EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos);
EXPECT_TRUE(output.find("VecSize:4") != std::string::npos);
EXPECT_TRUE(output.find("thread_raked") != std::string::npos);
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
TestX0X1(output, PatternType::X0, PatternType::X1);
}
TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern)
{
// Test printing warp raked pattern
using PatternType =
TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>;
PatternType pattern;
std::string output = CapturePrintOutput(pattern);
// Verify the output contains expected information
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos);
EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos);
EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos);
EXPECT_TRUE(output.find("VecSize:8") != std::string::npos);
EXPECT_TRUE(output.find("warp_raked") != std::string::npos);
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
TestX0X1(output, PatternType::X0, PatternType::X1);
}
TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern)
{
// Test printing block raked pattern
using PatternType =
TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>;
PatternType pattern;
std::string output = CapturePrintOutput(pattern);
// Verify the output contains expected information
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos);
EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos);
EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);
EXPECT_TRUE(output.find("VecSize:16") != std::string::npos);
EXPECT_TRUE(output.find("block_raked") != std::string::npos);
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
TestX0X1(output, PatternType::X0, PatternType::X1);
}
} // namespace ck_tile

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_print_common.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
class PrintTupleTest : public PrintTest
{
};
TEST_F(PrintTupleTest, PrintSimpleTuple)
{
// Test printing tuple with numbers
auto tup = make_tuple(number<1>{}, number<5>{}, number<8>{});
std::string output = CapturePrintOutput(tup);
// Verify the output format matches tuple print implementation
EXPECT_TRUE(output.find("tuple<") == 0);
EXPECT_TRUE(output.find("1") != std::string::npos);
EXPECT_TRUE(output.find("5") != std::string::npos);
EXPECT_TRUE(output.find("8") != std::string::npos);
EXPECT_TRUE(output.back() == '>');
}
TEST_F(PrintTupleTest, PrintSingleElementTuple)
{
// Test printing tuple with single element
auto tup = make_tuple(number<42>{});
std::string output = CapturePrintOutput(tup);
EXPECT_TRUE(output.find("tuple<") == 0);
EXPECT_TRUE(output.find("42") != std::string::npos);
EXPECT_TRUE(output.back() == '>');
}
TEST_F(PrintTupleTest, PrintEmptyTuple)
{
// Test printing empty tuple
auto tup = make_tuple();
std::string output = CapturePrintOutput(tup);
EXPECT_EQ(output, "tuple<>");
}
TEST_F(PrintTupleTest, PrintMixedTypeTuple)
{
// Test printing tuple with mixed types (numbers and constants)
auto tup = make_tuple(number<10>{}, constant<20>{}, number<30>{});
std::string output = CapturePrintOutput(tup);
EXPECT_TRUE(output.find("tuple<") == 0);
EXPECT_TRUE(output.find("10") != std::string::npos);
EXPECT_TRUE(output.find("20") != std::string::npos);
EXPECT_TRUE(output.find("30") != std::string::npos);
EXPECT_TRUE(output.back() == '>');
}
} // namespace ck_tile