mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
[CK_TILE] Enable printing more structures in CK-Tile (#2443)
* Add more printing to core cktile
* Revert other changes in static encoding pattern
* Refactor to using a free print() function
* Remove loops and print just the containers
* Print tuple with better formatting, fix sequence compilation
* Add some tests for print utility
* Add print utility header
* Print for static_encoding_pattern
* add buffer_view printing
* Align vector_traits
* Fix formatting
* Lower-case enum strings
Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
* Remove empty comment lines
* Fix test with lower-case too
* Reduce repeated code in print tests, move helper function closer to type definition, test X&Y
* Add test_print_common.hpp
* add print.hpp in core.hpp
---------
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: ffdee5e774]
This commit is contained in:
@@ -74,6 +74,7 @@
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -139,20 +140,19 @@ struct pass_through : public base_transform<1, 1>
|
||||
{
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
//
|
||||
printf("up_lengths_:");
|
||||
print(up_lengths_);
|
||||
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const pass_through<LowLength>& pt)
|
||||
{
|
||||
printf("pass_through{");
|
||||
|
||||
printf("up_lengths_: ");
|
||||
print(pt.get_upper_lengths());
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
@@ -229,29 +229,25 @@ struct pad : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
|
||||
ck_tile::is_known_at_compile_time<RightPadLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength,
|
||||
typename LeftPadLength,
|
||||
typename RightPadLength,
|
||||
bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const pad<LowLength, LeftPadLength, RightPadLength, SkipIsValidCheck>& p)
|
||||
{
|
||||
printf("pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(p.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(p.left_pad_length_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(p.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||
struct left_pad
|
||||
{
|
||||
@@ -330,24 +326,20 @@ struct left_pad
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("left_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("left_pad_length_: ");
|
||||
print(left_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const left_pad<LowLength, LeftPadLength, SkipIsValidCheck>& lp)
|
||||
{
|
||||
printf("left_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(lp.up_lengths_);
|
||||
printf(", left_pad_length_: ");
|
||||
print(lp.left_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
struct right_pad : public base_transform<1, 1>
|
||||
{
|
||||
@@ -430,24 +422,20 @@ struct right_pad : public base_transform<1, 1>
|
||||
// It's up to runtime to check the padding length should be multiple of vector length
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("right_pad{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("right_pad_length_: ");
|
||||
print(right_pad_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
CK_TILE_HOST_DEVICE static void
|
||||
print(const right_pad<LowLength, RightPadLength, SkipIsValidCheck>& rp)
|
||||
{
|
||||
printf("right_pad{");
|
||||
printf("up_lengths_: ");
|
||||
print(rp.up_lengths_);
|
||||
printf(", right_pad_length_: ");
|
||||
print(rp.right_pad_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
|
||||
// UpLengths and Coefficients can be either of the followings:
|
||||
// 1) Tuple of index_t, which is known at run-time, or
|
||||
@@ -532,24 +520,19 @@ struct embed : public base_transform<1, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<Coefficients>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("coefficients_: ");
|
||||
print(coefficients_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, typename Coefficients>
|
||||
CK_TILE_HOST_DEVICE static void print(const embed<UpLengths, Coefficients>& e)
|
||||
{
|
||||
printf("embed{");
|
||||
printf("up_lengths_: ");
|
||||
print(e.up_lengths_);
|
||||
printf(", coefficients_: ");
|
||||
print(e.coefficients_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
|
||||
{
|
||||
@@ -699,24 +682,19 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v2_magic_division<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v2_magic_division{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
|
||||
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
|
||||
// will be very bad
|
||||
@@ -830,29 +808,21 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Merge_v3_direct_division_mod{");
|
||||
|
||||
//
|
||||
printf("low_lengths_ ");
|
||||
print(low_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("low_lengths_scan_ ");
|
||||
print(low_lengths_scan_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_ ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const merge_v3_division_mod<LowLengths>& m)
|
||||
{
|
||||
printf("merge_v3_division_mod{");
|
||||
printf("low_lengths_: ");
|
||||
print(m.low_lengths_);
|
||||
printf(", low_lengths_scan_: ");
|
||||
print(m.low_lengths_scan_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
{
|
||||
@@ -958,24 +928,19 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("unmerge{");
|
||||
|
||||
//
|
||||
printf("up_lengths_");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("up_lengths_scan_");
|
||||
print(up_lengths_scan_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
CK_TILE_HOST_DEVICE static void print(const unmerge<UpLengths, Use24BitIntegerCalculation>& u)
|
||||
{
|
||||
printf("unmerge{");
|
||||
printf("up_lengths_: ");
|
||||
print(u.up_lengths_);
|
||||
printf(", up_lengths_scan_: ");
|
||||
print(u.up_lengths_scan_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
@@ -1023,19 +988,17 @@ struct freeze : public base_transform<1, 0>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("freeze{");
|
||||
|
||||
//
|
||||
printf("low_idx_: ");
|
||||
print(low_idx_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowerIndex>
|
||||
CK_TILE_HOST_DEVICE static void print(const freeze<LowerIndex>& f)
|
||||
{
|
||||
printf("freeze{");
|
||||
printf("low_idx_: ");
|
||||
print(f.low_idx_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// insert a dangling upper dimension without lower dimension
|
||||
template <typename UpperLength>
|
||||
struct insert : public base_transform<0, 1>
|
||||
@@ -1092,18 +1055,17 @@ struct insert : public base_transform<0, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpperLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("insert{");
|
||||
|
||||
//
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpperLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const insert<UpperLength>& i)
|
||||
{
|
||||
printf("insert{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// replicate the original tensor and create a higher dimensional tensor
|
||||
template <typename UpLengths>
|
||||
struct replicate : public base_transform<0, UpLengths::size()>
|
||||
@@ -1152,21 +1114,19 @@ struct replicate : public base_transform<0, UpLengths::size()>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("replicate{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//
|
||||
UpLengths up_lengths_;
|
||||
};
|
||||
|
||||
template <typename UpLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const replicate<UpLengths>& r)
|
||||
{
|
||||
printf("replicate{");
|
||||
printf("up_lengths_: ");
|
||||
print(r.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
struct slice : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1238,28 +1198,20 @@ struct slice : public base_transform<1, 1>
|
||||
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
|
||||
ck_tile::is_known_at_compile_time<SliceEnd>::value;
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("slice{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_begin_: ");
|
||||
print(slice_begin_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("slice_end_: ");
|
||||
print(slice_end_);
|
||||
|
||||
printf("}");
|
||||
} // namespace ck
|
||||
}; // namespace ck
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
CK_TILE_HOST_DEVICE static void print(const slice<LowLength, SliceBegin, SliceEnd>& s)
|
||||
{
|
||||
printf("slice{");
|
||||
printf("up_lengths_: ");
|
||||
print(s.up_lengths_);
|
||||
printf(", slice_begin_: ");
|
||||
print(s.slice_begin_);
|
||||
printf(", slice_end_: ");
|
||||
print(s.slice_end_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief lower_idx = upper_idx % modulus.
|
||||
@@ -1328,19 +1280,19 @@ struct modulo : public base_transform<1, 1>
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("Modulus{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Modulus, typename UpLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
|
||||
{
|
||||
printf("modulo{");
|
||||
printf("modulus_: ");
|
||||
print(m.modulus_);
|
||||
printf(", up_lengths_: ");
|
||||
print(m.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// 2D XOR, NOTE: "xor" is a keyword
|
||||
template <typename LowLengths>
|
||||
struct xor_t : public base_transform<2, 2>
|
||||
@@ -1424,20 +1376,17 @@ struct xor_t : public base_transform<2, 2>
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("xor_t{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
|
||||
{
|
||||
printf("xor_t{");
|
||||
printf("up_lengths_: ");
|
||||
print(x.up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
struct offset : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1509,24 +1458,19 @@ struct offset : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
ck_tile::is_known_at_compile_time<OffsetLength>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("offset{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("offset_length_: ");
|
||||
print(offset_length_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLength, typename OffsetLength>
|
||||
CK_TILE_HOST_DEVICE static void print(const offset<LowLength, OffsetLength>& o)
|
||||
{
|
||||
printf("offset{");
|
||||
printf("up_lengths_: ");
|
||||
print(o.up_lengths_);
|
||||
printf(", offset_length_: ");
|
||||
print(o.offset_length_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
@@ -1595,20 +1539,19 @@ struct indexing : public base_transform<1, 1>
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>& i)
|
||||
{
|
||||
printf("indexing{");
|
||||
printf("up_lengths_: ");
|
||||
print(i.up_lengths_);
|
||||
printf(", iadaptor_: ");
|
||||
print(i.iadaptor_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
|
||||
@@ -77,6 +77,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -317,4 +318,51 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to convert enum to string
|
||||
constexpr const char* tile_distribution_pattern_to_string(tile_distribution_pattern pattern)
|
||||
{
|
||||
switch(pattern)
|
||||
{
|
||||
case tile_distribution_pattern::thread_raked: return "thread_raked";
|
||||
case tile_distribution_pattern::warp_raked: return "warp_raked";
|
||||
case tile_distribution_pattern::block_raked: return "block_raked";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups>
|
||||
CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>&)
|
||||
{
|
||||
using PatternType = TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
DistributionPattern,
|
||||
NumWaveGroups>;
|
||||
|
||||
printf("TileDistributionEncodingPattern2D<BlockSize:%d, YPerTile:%d, XPerTile:%d, "
|
||||
"VecSize:%d, %s>: ",
|
||||
BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern_to_string(DistributionPattern));
|
||||
printf("{<Y0, Y1, Y2>: <%d, %d, %d>, <X0, X1>: <%d, %d>}\n",
|
||||
PatternType::Y0,
|
||||
PatternType::Y1,
|
||||
PatternType::Y2,
|
||||
PatternType::X0,
|
||||
PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -218,4 +218,19 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Helper function to convert address space enum to string
|
||||
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
|
||||
{
|
||||
switch(addr_space)
|
||||
{
|
||||
case address_space_enum::generic: return "generic";
|
||||
case address_space_enum::global: return "global";
|
||||
case address_space_enum::lds: return "lds";
|
||||
case address_space_enum::sgpr: return "sgpr";
|
||||
case address_space_enum::constant: return "constant";
|
||||
case address_space_enum::vgpr: return "vgpr";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -177,9 +177,27 @@ struct array<T, 0>
|
||||
CK_TILE_HOST_DEVICE constexpr array() {}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, N>& a)
|
||||
{
|
||||
printf("array{size: %ld, data: [", static_cast<long>(N));
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(a[i]);
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static void print(const array<T, 0>&)
|
||||
{
|
||||
printf("array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
|
||||
@@ -139,26 +139,21 @@ struct map
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("map{size_: %d, ", size_);
|
||||
//
|
||||
printf("impl_: [");
|
||||
//
|
||||
for(const auto& [k, d] : *this)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
//
|
||||
printf("]");
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename key, typename data, index_t max_size>
|
||||
CK_TILE_HOST_DEVICE static void print(const map<key, data, max_size>& m)
|
||||
{
|
||||
printf("map{size_: %d, impl_: [", m.size_);
|
||||
for(const auto& [k, d] : m)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
printf("]}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,13 +9,10 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t, index_t, index_t>
|
||||
struct static_for;
|
||||
|
||||
template <index_t...>
|
||||
struct sequence;
|
||||
|
||||
@@ -196,15 +193,24 @@ struct sequence
|
||||
{
|
||||
return sequence<f(Is)...>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
printf("sequence{size: %d, data: [", size());
|
||||
((printf("%d ", Is)), ...);
|
||||
printf("]}");
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
|
||||
{
|
||||
printf("sequence<");
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
(([&first](index_t value) {
|
||||
printf("%s%d", first ? "" : ", ", value);
|
||||
first = false;
|
||||
}(Is)),
|
||||
...);
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename T, T... Ints>
|
||||
struct __integer_sequence;
|
||||
|
||||
@@ -300,12 +300,29 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
template <typename... T>
|
||||
CK_TILE_HOST_DEVICE void print(const tuple<T...>& t)
|
||||
{
|
||||
printf("tuple<");
|
||||
if constexpr(sizeof...(T) > 0)
|
||||
{
|
||||
bool first = true;
|
||||
static_for<0, sizeof...(T), 1>{}([&t, &first](auto i) {
|
||||
if(!first)
|
||||
printf(", ");
|
||||
print(t.get(i));
|
||||
first = false;
|
||||
});
|
||||
}
|
||||
printf(">");
|
||||
}
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename... T>
|
||||
struct vector_traits<tuple<T...>>
|
||||
struct vector_traits<tuple<T...>, void>
|
||||
{
|
||||
using scalar_type = __type_pack_element<0, T...>;
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
|
||||
@@ -19,14 +19,18 @@ struct constant
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <auto v>
|
||||
CK_TILE_HOST_DEVICE static void print(const constant<v>&)
|
||||
{
|
||||
printf("%ld", static_cast<long>(v));
|
||||
}
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
|
||||
@@ -84,7 +84,7 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T, typename>
|
||||
template <typename T, typename = void>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
@@ -94,7 +94,7 @@ struct vector_traits
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
|
||||
{
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
static constexpr index_t vector_size = N;
|
||||
|
||||
@@ -210,28 +210,6 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: generic, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Global
|
||||
@@ -757,28 +735,6 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Global, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: LDS
|
||||
@@ -1138,28 +1094,6 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Lds, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// Address Space: Vgpr
|
||||
@@ -1313,28 +1247,6 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// FIXME: remove
|
||||
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("buffer_view{");
|
||||
|
||||
// AddressSpace
|
||||
printf("AddressSpace: Vgpr, ");
|
||||
|
||||
// p_data_
|
||||
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
|
||||
|
||||
// buffer_size_
|
||||
printf("buffer_size_: ");
|
||||
print(buffer_size_);
|
||||
printf(", ");
|
||||
|
||||
// invalid_element_value_
|
||||
printf("invalid_element_value_: ");
|
||||
print(invalid_element_value_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
@@ -1360,4 +1272,25 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
p, buffer_size, invalid_element_value};
|
||||
}
|
||||
|
||||
// Generalized print function for all buffer_view variants
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence>
|
||||
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
|
||||
T,
|
||||
BufferSizeType,
|
||||
InvalidElementUseNumericalZeroValue,
|
||||
Coherence>& bv)
|
||||
{
|
||||
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
|
||||
address_space_to_string(BufferAddressSpace),
|
||||
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
|
||||
print(bv.buffer_size_);
|
||||
printf(", invalid_element_value_: ");
|
||||
print(bv.invalid_element_value_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -305,42 +305,45 @@ struct tensor_adaptor
|
||||
get_container_subset(vector_strides, top_dims));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_adaptor{");
|
||||
|
||||
//
|
||||
printf("transforms: ");
|
||||
print(transforms_);
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("LowerDimensionHiddenIds: ");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("UpperDimensionHiddenIds: ");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("BottomDimensionHiddenIds: ");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf(", ");
|
||||
|
||||
//
|
||||
printf("TopDimensionHiddenIds: ");
|
||||
print(TopDimensionHiddenIds{});
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
private:
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
BottomDimensionHiddenIds,
|
||||
TopDimensionHiddenIds>& adaptor)
|
||||
{
|
||||
printf("tensor_adaptor{\n");
|
||||
printf(" transforms: [");
|
||||
print(adaptor.get_transforms());
|
||||
printf("],\n");
|
||||
|
||||
printf(" LowerDimensionHiddenIds: [");
|
||||
print(LowerDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" UpperDimensionHiddenIds: [");
|
||||
print(UpperDimensionHiddenIdss{});
|
||||
printf("],\n");
|
||||
|
||||
printf(" BottomDimensionHiddenIds: [");
|
||||
print(BottomDimensionHiddenIds{});
|
||||
printf("],\n");
|
||||
|
||||
//
|
||||
printf(" TopDimensionHiddenIds: [");
|
||||
print(TopDimensionHiddenIds{});
|
||||
printf("]\n}\n");
|
||||
}
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
|
||||
@@ -140,25 +140,37 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_descriptor{");
|
||||
|
||||
// tensor_adaptor
|
||||
Base::print();
|
||||
printf(", ");
|
||||
|
||||
// element_space_size_
|
||||
printf("element_space_size_: ");
|
||||
print(element_space_size_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths,
|
||||
typename GuaranteedVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>& descriptor)
|
||||
{
|
||||
printf("tensor_descriptor{\n");
|
||||
// first print the tensor adaptor part of the descriptor using the base class print
|
||||
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n",
|
||||
static_cast<long>(descriptor.get_element_space_size().value));
|
||||
printf("guaranteed_vector_lengths: ");
|
||||
print(GuaranteedVectorLengths{});
|
||||
printf(",\nguaranteed_vector_strides: ");
|
||||
print(GuaranteedVectorStrides{});
|
||||
printf("}\n}\n");
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename ElementSpaceSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
|
||||
|
||||
@@ -228,24 +228,6 @@ struct tile_distribution
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
//
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(DstrEncode{});
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_d_: ");
|
||||
print(ys_to_d_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
@@ -710,4 +692,27 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
|
||||
Ys2DDescriptor_,
|
||||
StaticTileDistributionEncoding_,
|
||||
TileDistributionDetail_>& distribution)
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(StaticTileDistributionEncoding_{});
|
||||
printf(", ");
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(distribution.ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
printf("ys_to_d_: ");
|
||||
print(distribution.ys_to_d_);
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -428,109 +428,7 @@ struct tile_distribution_encoding
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_span_major_: ");
|
||||
print(ndim_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_lengthss_: ");
|
||||
print(rhs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_lengths_: ");
|
||||
print(ys_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_span_minor_: ");
|
||||
print(ndims_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_major_: ");
|
||||
print(ys_to_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(ys_to_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(ps_over_rs_derivative_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
//
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
|
||||
//
|
||||
printf("rs_lengths_: ");
|
||||
print(rs_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("hs_lengthss_: ");
|
||||
print(hs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("detail: ");
|
||||
print(detail{});
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename encoding, typename shuffle>
|
||||
@@ -896,4 +794,106 @@ make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Free print function for tile_distribution_encoding::detail
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
print(const typename tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>::detail& detail_obj)
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("ndim_span_major_: ");
|
||||
print(detail_obj.ndim_span_major_);
|
||||
printf(", ");
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(detail_obj.ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("ndim_rh_major_: ");
|
||||
print(detail_obj.ndim_rh_major_);
|
||||
printf(", ");
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(detail_obj.max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
printf("rhs_lengthss_: ");
|
||||
print(detail_obj.rhs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ys_lengths_: ");
|
||||
print(detail_obj.ys_lengths_);
|
||||
printf(", ");
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(detail_obj.rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
printf("ndims_span_minor_: ");
|
||||
print(detail_obj.ndims_span_minor_);
|
||||
printf(", ");
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(detail_obj.max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_major_: ");
|
||||
print(detail_obj.ys_to_span_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(detail_obj.ys_to_span_minor_);
|
||||
printf(", ");
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(detail_obj.distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(detail_obj.ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(detail_obj.ps_over_rs_derivative_);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// Free print function for tile_distribution_encoding
|
||||
template <typename RsLengths_,
|
||||
typename HsLengthss_,
|
||||
typename Ps2RHssMajor_,
|
||||
typename Ps2RHssMinor_,
|
||||
typename Ys2RHsMajor_,
|
||||
typename Ys2RHsMinor_>
|
||||
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding<RsLengths_,
|
||||
HsLengthss_,
|
||||
Ps2RHssMajor_,
|
||||
Ps2RHssMinor_,
|
||||
Ys2RHsMajor_,
|
||||
Ys2RHsMinor_>& encoding)
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", encoding.NDimX, encoding.NDimP, encoding.NDimY);
|
||||
printf("rs_lengths_: ");
|
||||
print(encoding.rs_lengths_);
|
||||
printf(", ");
|
||||
printf("hs_lengthss_: ");
|
||||
print(encoding.hs_lengthss_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(encoding.ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(encoding.ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(encoding.ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(encoding.ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
76
include/ck_tile/core/utility/print.hpp
Normal file
76
include/ck_tile/core/utility/print.hpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
|
||||
/// can be printed.
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print(const T&)
|
||||
{
|
||||
static_assert(sizeof(T) == 0,
|
||||
"No print implementation available for this type. Please specialize "
|
||||
"ck_tile::print for your type.");
|
||||
}
|
||||
|
||||
/// Specialization for int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const int& value)
|
||||
{
|
||||
printf("%d", value);
|
||||
}
|
||||
|
||||
/// Specialization for float
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const float& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for double
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const double& value)
|
||||
{
|
||||
printf("%f", value);
|
||||
}
|
||||
|
||||
/// Specialization for long
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const long& value)
|
||||
{
|
||||
printf("%ld", value);
|
||||
}
|
||||
|
||||
/// Specialization for unsigned int
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const unsigned int& value)
|
||||
{
|
||||
printf("%u", value);
|
||||
}
|
||||
|
||||
/// Specialization for char
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void print(const char& value)
|
||||
{
|
||||
printf("%c", value);
|
||||
}
|
||||
|
||||
/// Specialization for array
|
||||
template <typename T, size_t N>
|
||||
CK_TILE_HOST_DEVICE void print(const T (&value)[N])
|
||||
{
|
||||
printf("[");
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
printf(", ");
|
||||
print(value[i]); // Recursively call print for each element
|
||||
}
|
||||
printf("]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -21,4 +21,5 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
# add_subdirectory(layernorm2d)
|
||||
# add_subdirectory(rmsnorm2d)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
4
test/ck_tile/utility/CMakeLists.txt
Normal file
4
test/ck_tile/utility/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
message("-- Adding: test/ck_tile/utility/")
|
||||
|
||||
# Add print tests
|
||||
add_subdirectory(print)
|
||||
8
test/ck_tile/utility/print/CMakeLists.txt
Normal file
8
test/ck_tile/utility/print/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# Print utility tests
|
||||
add_gtest_executable(test_print_sequence test_print_sequence.cpp)
|
||||
add_gtest_executable(test_print_array test_print_array.cpp)
|
||||
add_gtest_executable(test_print_tuple test_print_tuple.cpp)
|
||||
add_gtest_executable(test_print_coordinate_transform test_print_coordinate_transform.cpp)
|
||||
add_gtest_executable(test_print_static_encoding_pattern test_print_static_encoding_pattern.cpp)
|
||||
add_gtest_executable(test_print_buffer_view test_print_buffer_view.cpp)
|
||||
add_gtest_executable(test_print_basic_types test_print_basic_types.cpp)
|
||||
70
test/ck_tile/utility/print/README.md
Normal file
70
test/ck_tile/utility/print/README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# Print Function Tests
|
||||
|
||||
This directory contains unit tests for testing the print functionality of various data structures and coordinate transformations in the composable_kernel library.
|
||||
|
||||
## Tests Included
|
||||
|
||||
### test_print_sequence.cpp
|
||||
Tests the print functionality for `sequence<...>` containers:
|
||||
- Simple sequences with multiple elements
|
||||
- Single element sequences
|
||||
- Empty sequences
|
||||
- Longer sequences
|
||||
|
||||
### test_print_array.cpp
|
||||
Tests the print functionality for `array<T, N>` containers:
|
||||
- Arrays with integer values
|
||||
- Single element arrays
|
||||
- Empty arrays (size 0)
|
||||
- Arrays with floating point values
|
||||
|
||||
### test_print_tuple.cpp
|
||||
Tests the print functionality for `tuple<...>` containers:
|
||||
- Simple tuples with numbers
|
||||
- Single element tuples
|
||||
- Empty tuples
|
||||
- Mixed type tuples
|
||||
|
||||
### test_print_coordinate_transform.cpp
|
||||
Tests the print functionality for coordinate transformation structures:
|
||||
- `pass_through` transform
|
||||
- `embed` transform
|
||||
- `merge` transform
|
||||
- `unmerge` transform
|
||||
- `freeze` transform
|
||||
|
||||
## Testing Approach
|
||||
|
||||
All tests use Google Test's `CaptureStdout()` functionality to capture the output from print functions and verify the formatting:
|
||||
|
||||
```cpp
|
||||
testing::internal::CaptureStdout();
|
||||
print(object);
|
||||
std::string output = testing::internal::GetCapturedStdout();
|
||||
EXPECT_EQ(output, "expected_format");
|
||||
```
|
||||
|
||||
This approach enables testing of print function output without affecting the console during test execution.
|
||||
|
||||
## Building and Running
|
||||
|
||||
The tests are integrated into the CMake build system. To build and run the print tests:
|
||||
|
||||
```bash
|
||||
# Build the specific test
|
||||
make test_print_sequence
|
||||
|
||||
# Run the test
|
||||
./test_print_sequence
|
||||
|
||||
# Or run all print tests using CTest
|
||||
ctest -R "test_print"
|
||||
```
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
To add tests for new data structures:
|
||||
|
||||
1. Create a new test file: `test_print_<structure_name>.cpp`
|
||||
2. Follow the existing pattern using `CaptureStdout()`
|
||||
3. Add the test executable to `CMakeLists.txt`
|
||||
59
test/ck_tile/utility/print/test_print_array.cpp
Normal file
59
test/ck_tile/utility/print/test_print_array.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintArrayTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintArrayTest, PrintIntArray)
|
||||
{
|
||||
// Test printing array<int, 3>
|
||||
array<int, 3> arr{10, 20, 30};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// The expected format should match the array print function implementation
|
||||
EXPECT_EQ(output, "array{size: 3, data: [10, 20, 30]}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintSingleElementArray)
|
||||
{
|
||||
// Test printing array<int, 1>
|
||||
array<int, 1> arr{42};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "array{size: 1, data: [42]}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintEmptyArray)
|
||||
{
|
||||
// Test printing array<int, 0> (empty array)
|
||||
array<int, 0> arr{};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "array{size: 0, data: []}");
|
||||
}
|
||||
|
||||
TEST_F(PrintArrayTest, PrintFloatArray)
|
||||
{
|
||||
// Test printing array with float values
|
||||
array<float, 2> arr{3.14f, 2.71f};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// Note: float printing format may vary, so we'll test for basic structure
|
||||
EXPECT_TRUE(output.find("array{size: 2, data: [") == 0);
|
||||
EXPECT_TRUE(output.find("3.14") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("2.71") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("]}") == output.length() - 2);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
76
test/ck_tile/utility/print/test_print_basic_types.cpp
Normal file
76
test/ck_tile/utility/print/test_print_basic_types.cpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintBasicTypesTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintIntArray)
|
||||
{
|
||||
int arr[4] = {1, 2, 3, 4};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[1, 2, 3, 4]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintFloatArray)
|
||||
{
|
||||
float arr[3] = {1.5f, 2.5f, 3.5f};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
// Note: floating point formatting may vary, so we check for key elements
|
||||
EXPECT_TRUE(output.find("[") == 0);
|
||||
EXPECT_TRUE(output.find("1.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("2.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("3.5") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == ']');
|
||||
EXPECT_TRUE(output.find(", ") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintDoubleArray)
|
||||
{
|
||||
double arr[2] = {10.123, 20.456};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_TRUE(output.find("[") == 0);
|
||||
EXPECT_TRUE(output.find("10.123") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("20.456") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == ']');
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintUnsignedIntArray)
|
||||
{
|
||||
unsigned int arr[3] = {100u, 200u, 300u};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[100, 200, 300]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintCharArray)
|
||||
{
|
||||
char arr[5] = {'a', 'b', 'c', 'd', 'e'};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[a, b, c, d, e]");
|
||||
}
|
||||
|
||||
TEST_F(PrintBasicTypesTest, PrintSingleElementArray)
|
||||
{
|
||||
int arr[1] = {42};
|
||||
|
||||
std::string output = CapturePrintOutput(arr);
|
||||
|
||||
EXPECT_EQ(output, "[42]");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
78
test/ck_tile/utility/print/test_print_buffer_view.cpp
Normal file
78
test/ck_tile/utility/print/test_print_buffer_view.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintBufferViewTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintGenericBufferView)
|
||||
{
|
||||
// Test printing generic address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::generic>(&data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: generic") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintGlobalBufferView)
|
||||
{
|
||||
// Test printing global address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::global>(&data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: global") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintLdsBufferView)
|
||||
{
|
||||
// Test printing LDS address space buffer_view
|
||||
float data[4] = {100.f, 200.f, 300.f, 400.f};
|
||||
auto bv = make_buffer_view<address_space_enum::lds>(data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: lds") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(PrintBufferViewTest, PrintVgprBufferView)
|
||||
{
|
||||
// Test printing VGPR address space buffer_view
|
||||
float data[4] = {1.5f, 2.5f, 3.5f, 4.5f};
|
||||
auto bv = make_buffer_view<address_space_enum::vgpr>(data, 4);
|
||||
|
||||
std::string output = CapturePrintOutput(bv);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("buffer_view{AddressSpace: vgpr") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("p_data_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("buffer_size_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("invalid_element_value_:") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("}") != std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
25
test/ck_tile/utility/print/test_print_common.hpp
Normal file
25
test/ck_tile/utility/print/test_print_common.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gtest/gtest-spi.h>
|
||||
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
class PrintTest : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
// Helper function to capture and return the output of a print function
|
||||
template <typename T>
|
||||
std::string CapturePrintOutput(const T& type)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
testing::internal::CaptureStdout();
|
||||
print(type);
|
||||
return testing::internal::GetCapturedStdout();
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintCoordinateTransformTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintPassThrough)
|
||||
{
|
||||
// Test printing pass_through transform
|
||||
auto pt = make_pass_through_transform(number<32>{});
|
||||
|
||||
std::string output = CapturePrintOutput(pt);
|
||||
|
||||
// Verify it contains the pass_through identifier and some structure
|
||||
EXPECT_TRUE(output.find("pass_through{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintEmbed)
|
||||
{
|
||||
// Test printing embed transform
|
||||
auto embed_transform = make_embed_transform(make_tuple(number<4>{}, number<8>{}),
|
||||
make_tuple(number<1>{}, number<4>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(embed_transform);
|
||||
|
||||
// Verify it contains the embed identifier and key fields
|
||||
EXPECT_TRUE(output.find("embed{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("coefficients_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintMerge)
|
||||
{
|
||||
// Test printing merge transform
|
||||
auto merge_transform = make_merge_transform(make_tuple(number<4>{}, number<8>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(merge_transform);
|
||||
|
||||
// Verify it contains merge identifier and key fields
|
||||
EXPECT_TRUE(output.find("merge") ==
|
||||
0); // Could be merge_v2_magic_division or merge_v3_division_mod
|
||||
EXPECT_TRUE(output.find("low_lengths_") != std::string::npos ||
|
||||
output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintUnmerge)
|
||||
{
|
||||
// Test printing unmerge transform
|
||||
auto unmerge_transform = make_unmerge_transform(make_tuple(number<4>{}, number<8>{}));
|
||||
|
||||
std::string output = CapturePrintOutput(unmerge_transform);
|
||||
|
||||
// Verify it contains the unmerge identifier and key fields
|
||||
EXPECT_TRUE(output.find("unmerge{") == 0);
|
||||
EXPECT_TRUE(output.find("up_lengths_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
TEST_F(PrintCoordinateTransformTest, PrintFreeze)
|
||||
{
|
||||
// Test printing freeze transform
|
||||
auto freeze_transform = make_freeze_transform(number<5>{});
|
||||
|
||||
std::string output = CapturePrintOutput(freeze_transform);
|
||||
|
||||
// Verify it contains the freeze identifier and key fields
|
||||
EXPECT_TRUE(output.find("freeze{") == 0);
|
||||
EXPECT_TRUE(output.find("low_idx_") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '}');
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
45
test/ck_tile/utility/print/test_print_sequence.cpp
Normal file
45
test/ck_tile/utility/print/test_print_sequence.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintSequenceTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintSimpleSequence)
|
||||
{
|
||||
// Test printing sequence<1, 5, 8>
|
||||
constexpr auto seq = sequence<1, 5, 8>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
// Verify the output format
|
||||
EXPECT_EQ(output, "sequence<1, 5, 8>");
|
||||
}
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintSingleElementSequence)
|
||||
{
|
||||
// Test printing sequence<42>
|
||||
constexpr auto seq = sequence<42>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
EXPECT_EQ(output, "sequence<42>");
|
||||
}
|
||||
|
||||
TEST_F(PrintSequenceTest, PrintEmptySequence)
|
||||
{
|
||||
// Test printing sequence<> (empty sequence)
|
||||
constexpr auto seq = sequence<>{};
|
||||
|
||||
std::string output = CapturePrintOutput(seq);
|
||||
|
||||
EXPECT_EQ(output, "sequence<>");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintStaticEncodingPatternTest : public PrintTest
|
||||
{
|
||||
protected:
|
||||
void TestY0Y1Y2(const std::string& output, auto Y0, auto Y1, auto Y2)
|
||||
{
|
||||
std::stringstream expected;
|
||||
expected << "<Y0, Y1, Y2>: <" << Y0 << ", " << Y1 << ", " << Y2 << ">";
|
||||
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
|
||||
}
|
||||
void TestX0X1(const std::string& output, auto X0, auto X1)
|
||||
{
|
||||
std::stringstream expected;
|
||||
expected << "<X0, X1>: <" << X0 << ", " << X1 << ">";
|
||||
EXPECT_TRUE(output.find(expected.str()) != std::string::npos);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern)
|
||||
{
|
||||
// Test printing thread raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:4") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("thread_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern)
|
||||
{
|
||||
// Test printing warp raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:8") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("warp_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern)
|
||||
{
|
||||
// Test printing block raked pattern
|
||||
using PatternType =
|
||||
TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>;
|
||||
PatternType pattern;
|
||||
|
||||
std::string output = CapturePrintOutput(pattern);
|
||||
|
||||
// Verify the output contains expected information
|
||||
EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("VecSize:16") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("block_raked") != std::string::npos);
|
||||
TestY0Y1Y2(output, PatternType::Y0, PatternType::Y1, PatternType::Y2);
|
||||
TestX0X1(output, PatternType::X0, PatternType::X1);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
66
test/ck_tile/utility/print/test_print_tuple.cpp
Normal file
66
test/ck_tile/utility/print/test_print_tuple.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_print_common.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
class PrintTupleTest : public PrintTest
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(PrintTupleTest, PrintSimpleTuple)
|
||||
{
|
||||
// Test printing tuple with numbers
|
||||
auto tup = make_tuple(number<1>{}, number<5>{}, number<8>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
// Verify the output format matches tuple print implementation
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("1") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("5") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("8") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintSingleElementTuple)
|
||||
{
|
||||
// Test printing tuple with single element
|
||||
auto tup = make_tuple(number<42>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("42") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintEmptyTuple)
|
||||
{
|
||||
// Test printing empty tuple
|
||||
auto tup = make_tuple();
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_EQ(output, "tuple<>");
|
||||
}
|
||||
|
||||
TEST_F(PrintTupleTest, PrintMixedTypeTuple)
|
||||
{
|
||||
// Test printing tuple with mixed types (numbers and constants)
|
||||
auto tup = make_tuple(number<10>{}, constant<20>{}, number<30>{});
|
||||
|
||||
std::string output = CapturePrintOutput(tup);
|
||||
|
||||
EXPECT_TRUE(output.find("tuple<") == 0);
|
||||
EXPECT_TRUE(output.find("10") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("20") != std::string::npos);
|
||||
EXPECT_TRUE(output.find("30") != std::string::npos);
|
||||
EXPECT_TRUE(output.back() == '>');
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user