[CK_TILE] Enable printing more structures in CK-Tile (#2443)

* Add more printing to core cktile

* Revert other changes in static encoding pattern

* Refactor to using a free print() function

* Remove loops and print just the containers

* Print tuple with better formatting, fix sequence compilation

* Add some tests for print utility

* Add print utility header

* Print for static_encoding_pattern

* add buffer_view printing

* Align vector_traits

* Fix formatting

* Lower-case enum strings

Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>

* Remove empty comment lines

* Fix test with lower-case too

* Reduce repeated code in print tests, move helper function closer to type definition, test X&Y

* Add test_print_common.hpp

* add print.hpp in core.hpp

---------

Co-authored-by: Aviral Goel <aviral.goel@amd.com>
Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Sami Remes
2025-08-07 15:45:27 +03:00
committed by GitHub
parent 21e9983913
commit ffdee5e774
28 changed files with 1211 additions and 531 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength>
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
{
printf("pass_through{");
printf("up_lengths_: ");
print(pt.get_upper_lengths());
printf("}");
}
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
{
printf("pad{");
printf("up_lengths_: ");
print(p.up_lengths_);
printf(", left_pad_length_: ");
print(p.left_pad_length_);
printf(", right_pad_length_: ");
print(p.right_pad_length_);
printf("}");
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
@@ -330,24 +326,20 @@ struct left_pad
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
{
printf("left_pad{");
printf("up_lengths_: ");
print(lp.up_lengths_);
printf(", left_pad_length_: ");
print(lp.left_pad_length_);
printf("}");
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
CK_TILE_HOST_DEVICE static void
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
{
printf("right_pad{");
printf("up_lengths_: ");
print(rp.up_lengths_);
printf(", right_pad_length_: ");
print(rp.right_pad_length_);
printf("}");
}
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename UpLengths, typename Coefficients>
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
{
printf("embed{");
printf("up_lengths_: ");
print(e.up_lengths_);
printf(", coefficients_: ");
print(e.coefficients_);
printf("}");
}
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
{
printf("merge_v2_magic_division{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
{
printf("merge_v3_division_mod{");
printf("low_lengths_: ");
print(m.low_lengths_);
printf(", low_lengths_scan_: ");
print(m.low_lengths_scan_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
{
printf("unmerge{");
printf("up_lengths_: ");
print(u.up_lengths_);
printf(", up_lengths_scan_: ");
print(u.up_lengths_scan_);
printf("}");
}
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
template <typename LowerIndex>
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
{
printf("freeze{");
printf("low_idx_: ");
print(f.low_idx_);
printf("}");
}
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
template <typename UpperLength>
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
{
printf("insert{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf("}");
}
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename UpLengths>
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
{
printf("replicate{");
printf("up_lengths_: ");
print(r.up_lengths_);
printf("}");
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // namespace ck
}; // namespace ck
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
{
printf("slice{");
printf("up_lengths_: ");
print(s.up_lengths_);
printf(", slice_begin_: ");
print(s.slice_begin_);
printf(", slice_end_: ");
print(s.slice_end_);
printf("}");
}
/*
* \brief lower_idx = upper_idx % modulus.
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
{
printf("modulo{");
printf("modulus_: ");
print(m.modulus_);
printf(", up_lengths_: ");
print(m.up_lengths_);
printf("}");
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
{
printf("xor_t{");
printf("up_lengths_: ");
print(x.up_lengths_);
printf("}");
}
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
{
printf("offset{");
printf("up_lengths_: ");
print(o.up_lengths_);
printf(", offset_length_: ");
print(o.offset_length_);
printf("}");
}
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
{
printf("indexing{");
printf("up_lengths_: ");
print(i.up_lengths_);
printf(", iadaptor_: ");
print(i.iadaptor_);
printf("}");
}
//*******************************************************************************************************
template <typename LowLength>

View File

@@ -77,6 +77,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
@@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
}
};
// Helper function to convert enum to string
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
{
switch(pattern)
{
case tile_distribution_pattern::thread_raked: return "thread_raked";
case tile_distribution_pattern::warp_raked: return "warp_raked";
case tile_distribution_pattern::block_raked: return "block_raked";
default: return "unknown";
}
}
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups>
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>&)
{
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
DistributionPattern,
NumWaveGroups>;
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
"VecSize:%d, %s>: ",
BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern_to_string(DistributionPattern));
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
PatternType::Y0,
PatternType::Y1,
PatternType::Y2,
PatternType::X0,
PatternType::X1);
}
} // namespace ck_tile

View File

@@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
#endif
}
/// Helper function to convert address space enum to string
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
switch(addr_space)
{
case address_space_enum::generic: return "generic";
case address_space_enum::global: return "global";
case address_space_enum::lds: return "lds";
case address_space_enum::sgpr: return "sgpr";
case address_space_enum::constant: return "constant";
case address_space_enum::vgpr: return "vgpr";
default: return "unknown";
}
}
} // namespace ck_tile

View File

@@ -177,9 +177,27 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename T, index_t N>
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
{
printf("array{size: %ld, data: [", static_cast<long>(N));
for(index_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(a[i]);
}
printf("]}");
}
template <typename T>
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
{
printf("array{size: 0, data: []}");
}
template <typename, typename>
struct vector_traits;

View File

@@ -139,26 +139,21 @@ struct map
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
template <typename key, typename data, index_t max_size>
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
{
printf("map{size_: %d, impl_: [", m.size_);
for(const auto& [k, d] : m)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
printf("]}");
}
} // namespace ck_tile

View File

@@ -9,13 +9,10 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/print.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -196,15 +193,24 @@ struct sequence
{
return sequence<f(Is)...>{};
}
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
printf("sequence<");
if constexpr(sizeof...(Is) > 0)
{
bool first = true;
(([&first](index_t value) {
printf("%s%d", first ? "" : ", ", value);
first = false;
}(Is)),
...);
}
printf(">");
}
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;

View File

@@ -300,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename, typename = void>
template <typename... T>
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
{
printf("tuple<");
if constexpr(sizeof...(T) > 0)
{
bool first = true;
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
if(!first)
printf(", ");
print(t.get(i));
first = false;
});
}
printf(">");
}
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
struct vector_traits<tuple<T...>, void>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);

View File

@@ -19,14 +19,18 @@ struct constant
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <auto v>
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
{
printf("%ld", static_cast<long>(v));
}
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>

View File

@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename>
template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
@@ -94,7 +94,7 @@ struct vector_traits
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;

View File

@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -757,28 +735,6 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -1138,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
@@ -1313,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
@@ -1360,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
p, buffer_size, invalid_element_value};
}
// Generalized print function for all buffer_view variants
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
Coherence>& bv)
{
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
address_space_to_string(BufferAddressSpace),
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
print(bv.buffer_size_);
printf(", invalid_element_value_: ");
print(bv.invalid_element_value_);
printf("}");
}
} // namespace ck_tile

View File

@@ -305,42 +305,45 @@ struct tensor_adaptor
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
BottomDimensionHiddenIds,
TopDimensionHiddenIds>& adaptor)
{
printf("tensor_adaptor{\n");
printf(" transforms: [");
print(adaptor.get_transforms());
printf("],\n");
printf(" LowerDimensionHiddenIds: [");
print(LowerDimensionHiddenIdss{});
printf("],\n");
printf(" UpperDimensionHiddenIds: [");
print(UpperDimensionHiddenIdss{});
printf("],\n");
printf(" BottomDimensionHiddenIds: [");
print(BottomDimensionHiddenIds{});
printf("],\n");
//
printf(" TopDimensionHiddenIds: [");
print(TopDimensionHiddenIds{});
printf("]\n}\n");
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>

View File

@@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths,
typename GuaranteedVectorStrides>
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>& descriptor)
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");
print(GuaranteedVectorStrides{});
printf("}\n}\n");
}
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,

View File

@@ -228,24 +228,6 @@ struct tile_distribution
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
@@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
}
} // namespace detail
// Free print function for tile_distribution
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
typename TileDistributionDetail_>
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
Ys2DDescriptor_,
StaticTileDistributionEncoding_,
TileDistributionDetail_>& distribution)
{
printf("tile_distribution{");
printf("tile_distribution_encoding: ");
print(StaticTileDistributionEncoding_{});
printf(", ");
printf("ps_ys_to_xs_: ");
print(distribution.ps_ys_to_xs_);
printf(", ");
printf("ys_to_d_: ");
print(distribution.ys_to_d_);
printf("}\n");
}
} // namespace ck_tile

View File

@@ -428,109 +428,7 @@ struct tile_distribution_encoding
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
template <typename encoding, typename shuffle>
@@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
}
} // namespace detail
// Free print function for tile_distribution_encoding::detail
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void
print(const typename tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>::detail& detail_obj)
{
printf("tile_distribution_encoding::detail{");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("ndim_span_major_: ");
print(detail_obj.ndim_span_major_);
printf(", ");
printf("ndims_rhs_minor_: ");
print(detail_obj.ndims_rhs_minor_);
printf(", ");
printf("ndim_rh_major_: ");
print(detail_obj.ndim_rh_major_);
printf(", ");
printf("max_ndim_rh_minor_: ");
print(detail_obj.max_ndim_rh_minor_);
printf(", ");
printf("rhs_lengthss_: ");
print(detail_obj.rhs_lengthss_);
printf(", ");
printf("ys_lengths_: ");
print(detail_obj.ys_lengths_);
printf(", ");
printf("rhs_major_minor_to_ys_: ");
print(detail_obj.rhs_major_minor_to_ys_);
printf(", ");
printf("ndims_span_minor_: ");
print(detail_obj.ndims_span_minor_);
printf(", ");
printf("max_ndim_span_minor_: ");
print(detail_obj.max_ndim_span_minor_);
printf(", ");
printf("ys_to_span_major_: ");
print(detail_obj.ys_to_span_major_);
printf(", ");
printf("ys_to_span_minor_: ");
print(detail_obj.ys_to_span_minor_);
printf(", ");
printf("distributed_spans_lengthss_: ");
print(detail_obj.distributed_spans_lengthss_);
printf(", ");
printf("ndims_distributed_spans_minor_: ");
print(detail_obj.ndims_distributed_spans_minor_);
printf(", ");
printf("ps_over_rs_derivative_: ");
print(detail_obj.ps_over_rs_derivative_);
printf("}");
}
// Free print function for tile_distribution_encoding
template <typename RsLengths_,
typename HsLengthss_,
typename Ps2RHssMajor_,
typename Ps2RHssMinor_,
typename Ys2RHsMajor_,
typename Ys2RHsMinor_>
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
HsLengthss_,
Ps2RHssMajor_,
Ps2RHssMinor_,
Ys2RHsMajor_,
Ys2RHsMinor_>& encoding)
{
printf("tile_distribution_encoding{");
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
printf("rs_lengths_: ");
print(encoding.rs_lengths_);
printf(", ");
printf("hs_lengthss_: ");
print(encoding.hs_lengthss_);
printf(", ");
printf("ps_to_rhss_major_: ");
print(encoding.ps_to_rhss_major_);
printf(", ");
printf("ps_to_rhss_minor_: ");
print(encoding.ps_to_rhss_minor_);
printf(", ");
printf("ys_to_rhs_major_: ");
print(encoding.ys_to_rhs_major_);
printf(", ");
printf("ys_to_rhs_minor_: ");
print(encoding.ys_to_rhs_minor_);
printf(", ");
printf("}");
}
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>
CK_TILE_HOST_DEVICE void print(const T&)
{
static_assert(sizeof(T) == 0,
"No print implementation available for this type. Please specialize "
"ck_tile::print for your type.");
}
/// Specialization for int
template <>
CK_TILE_HOST_DEVICE void print(const int& value)
{
printf("%d", value);
}
/// Specialization for float
template <>
CK_TILE_HOST_DEVICE void print(const float& value)
{
printf("%f", value);
}
/// Specialization for double
template <>
CK_TILE_HOST_DEVICE void print(const double& value)
{
printf("%f", value);
}
/// Specialization for long
template <>
CK_TILE_HOST_DEVICE void print(const long& value)
{
printf("%ld", value);
}
/// Specialization for unsigned int
template <>
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
{
printf("%u", value);
}
/// Specialization for char
template <>
CK_TILE_HOST_DEVICE void print(const char& value)
{
printf("%c", value);
}
/// Specialization for array
template <typename T, size_t N>
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
{
printf("[");
for(size_t i = 0; i < N; ++i)
{
if(i > 0)
printf(", ");
print(value[i]); // Recursively call print for each element
}
printf("]");
}
} // namespace ck_tile