[CK_TILE] Improve device printing (#3198)

* [CK_TILE] Improve device printing

* fix host gtest build

* clean
This commit is contained in:
Yi DING
2025-11-14 09:46:06 +08:00
committed by GitHub
parent 2a73eb3bc0
commit 4a8b17d1a4
4 changed files with 211 additions and 68 deletions

View File

@@ -7,6 +7,8 @@
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
template <auto... val>
@@ -18,48 +20,6 @@ template <typename... type>
{
}
/// Compile-time string: the characters live in the template parameter pack `Xs`.
template <char... Xs>
struct str_literal
{
    /// The characters of the pack followed by a terminating NUL.
    static constexpr char data[sizeof...(Xs) + 1] = {Xs..., '\0'};
    /// Number of characters in the pack (the terminating NUL is not counted).
    static constexpr size_t size = sizeof...(Xs);

    /// Concatenation: the result holds this literal's characters followed by those of `rhs`.
    /// Only the type of `rhs` matters; its value is never read.
    template <char... Ys>
    CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
    {
        using joined_t = str_literal<Xs..., Ys...>;
        return joined_t{};
    }

    /// Repeats this literal `N` times, inserting the characters of `sep` between
    /// consecutive copies. `N == 0` yields the empty literal, `N == 1` yields this
    /// literal unchanged; every other `N` recurses on `N - 1`.
    template <index_t N, char... Ys>
    CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
    {
        if constexpr(N != 0 && N != 1)
        {
            // Each recursion step appends one "separator + payload" chunk.
            using sep_then_self_t = str_literal<Ys..., Xs...>;
            return duplicate_n<N - 1>(sep) + sep_then_self_t{};
        }
        else if constexpr(N == 1)
        {
            return str_literal<Xs...>{};
        }
        else
        {
            return str_literal<>{};
        }
    }
};
/// Builds a `str_literal` object whose template characters are the characters of `lit_`.
///
/// An index sequence with `constexpr_strlen(lit_)` entries is turned into a tuple of
/// `std::integral_constant`s by `makeTuple`; `std::apply` then hands those constants to the
/// lambda, which indexes `lit_` with each `decltype(indices)::value` to form the character pack.
/// The terminating NUL is therefore not part of the pack.
///
/// `lit_` appears twice in the expansion (once for indexing, once for its length), and its
/// characters are used as template arguments, so it has to be usable in a constant expression.
#define make_str_literal(lit_) \
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
/// Converts an index sequence `<I0, I1, ...>` into a tuple of
/// `std::integral_constant<size_t, I0>, std::integral_constant<size_t, I1>, ...`,
/// so that every index remains available as a compile-time constant type.
template <size_t... Idx>
constexpr auto makeTuple(std::index_sequence<Idx...>) noexcept
    -> std::tuple<std::integral_constant<size_t, Idx>...>
{
    using result_t = std::tuple<std::integral_constant<size_t, Idx>...>;
    return result_t{};
}
/// Counts the characters that precede the first NUL in `c`; usable in constant expressions.
constexpr size_t constexpr_strlen(const char* c)
{
    size_t len = 0;
    for(; c[len] != '\0'; ++len)
    {
        // The loop header does all the work: advance until the terminator is reached.
    }
    return len;
}
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor;
@@ -79,17 +39,29 @@ struct CK_PRINTF<ConvertTo,
str_literal<SUFFIXChars...>>
{
template <typename T>
CK_TILE_HOST_DEVICE static constexpr auto default_format()
CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
{
if constexpr(std::is_same_v<T, float>)
return make_str_literal("%8.3f");
return std::make_tuple(make_str_literal("%8.3f"), T{});
else if constexpr(std::is_same_v<T, int>)
return make_str_literal("%5d");
return std::make_tuple(make_str_literal("%5d"), T{});
else if constexpr(std::is_same_v<T, unsigned int>)
return make_str_literal("%5u");
return std::make_tuple(make_str_literal("%5u"), T{});
else if constexpr(sizeof(T) == 1)
return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
else if constexpr(sizeof(T) == 2)
return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
else if constexpr(sizeof(T) == 4)
return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
else
return make_str_literal("0x%08x");
static_assert(false, "Unsupported type");
}
/// Type of the first element (the format string literal) of the tuple returned by
/// `default_format_and_type<T>()`.
template <typename T>
using default_format_t = std::tuple_element_t<0, decltype(default_format_and_type<T>())>;
/// Type of the second element (the value type paired with the format) of the tuple returned
/// by `default_format_and_type<T>()`.
template <typename T>
using default_type_t = std::tuple_element_t<1, decltype(default_format_and_type<T>())>;
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
{
@@ -108,49 +80,58 @@ struct CK_PRINTF<ConvertTo,
return str_literal<SUFFIXChars...>{} + lf;
}
template <typename T, index_t N, typename Y, index_t... Is>
template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
std::integer_sequence<index_t, Is...>) const
std::integer_sequence<index_t, Is...>,
Args&&... args) const
{
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
decltype(default_format<Y>()),
str_literal<FMTChars...>>;
using FMT1 = std::
conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
printf(fmt_wrap_v.data,
get_thread_id(),
N,
args...,
bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
#pragma clang diagnostic pop
}
template <typename T, index_t N>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
template <typename T, index_t N, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
{
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
impl<T, N, ConvertTo_>(
buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
}
template <typename... TS>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
template <typename... TS, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
Args&&... args) const
{
return operator()(tensor.get_thread_buffer());
return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
}
};
template <typename ConvertTo = void,
typename FMT = str_literal<>,
typename PREFIX = str_literal<>,
typename SUFFIX = str_literal<>>
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
template <typename T>
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
{
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
if(get_thread_id() < get_warp_size())
print(std::forward<T>(x));
}
template <typename... Ts>
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
{
using base_t = CK_PRINTF<Ts...>;
template <typename T>
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
template <typename T, typename... Args>
CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
{
if(get_thread_id() < get_warp_size())
base_t::operator()(buf);
base_t::operator()(buf, std::forward<Args>(args)...);
}
};

View File

@@ -7,6 +7,51 @@
namespace ck_tile {
namespace str_literal_detail {

/// Converts an index sequence `<I0, I1, ...>` into a tuple of
/// `std::integral_constant<size_t, I>` values, one per index.
template <size_t... Idx>
constexpr auto makeTuple(std::index_sequence<Idx...>) noexcept
    -> std::tuple<std::integral_constant<size_t, Idx>...>
{
    return std::tuple<std::integral_constant<size_t, Idx>...>{};
}

/// Number of characters before the first NUL in `c`; usable in constant expressions.
constexpr size_t constexpr_strlen(const char* c)
{
    const char* cursor = c;
    while(*cursor != '\0')
    {
        ++cursor;
    }
    // Distance walked from the start equals the string length.
    return static_cast<size_t>(cursor - c);
}

} // namespace str_literal_detail
/// Compile-time string whose characters are carried by the template parameter pack `Xs`.
template <char... Xs>
struct str_literal
{
    /// The pack's characters followed by a terminating NUL.
    static constexpr char data[sizeof...(Xs) + 1] = {Xs..., '\0'};
    /// Character count, excluding the terminating NUL.
    static constexpr size_t size = sizeof...(Xs);

    /// Appends the characters of `rhs` to this literal. The argument's value is unused;
    /// everything is derived from its type.
    template <char... Ys>
    CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
    {
        using joined_t = str_literal<Xs..., Ys...>;
        return joined_t{};
    }

    /// Produces `N` copies of this literal with the characters of `sep` between neighbours.
    /// Zero copies give the empty literal; a single copy is this literal itself.
    template <size_t N, char... Ys>
    CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
    {
        if constexpr(N >= 2)
        {
            // Build the first N - 1 copies, then add "separator + one more copy".
            using sep_then_self_t = str_literal<Ys..., Xs...>;
            return duplicate_n<N - 1>(sep) + sep_then_self_t{};
        }
        else if constexpr(N == 1)
        {
            return str_literal<Xs...>{};
        }
        else
        {
            return str_literal<>{};
        }
    }
};
/// Builds a `str_literal` object whose template characters are the characters of `lit_`.
///
/// `str_literal_detail::makeTuple` turns an index sequence of length
/// `str_literal_detail::constexpr_strlen(lit_)` into a tuple of `std::integral_constant`s;
/// `std::apply` passes them to the lambda, which indexes `lit_` with each
/// `decltype(indices)::value` to form the character pack. The terminating NUL is not included.
///
/// `lit_` appears twice in the expansion and its characters become template arguments, so it
/// has to be usable in a constant expression.
#define make_str_literal(lit_) \
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
str_literal_detail::makeTuple( \
std::make_index_sequence<str_literal_detail::constexpr_strlen(lit_)>()))
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
/// can be printed.
template <typename T>