mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Improve device printing (#3198)
* [CK_TILE] Improve device printing * fix host gtest build * clean
This commit is contained in:
@@ -7,6 +7,8 @@
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
@@ -18,48 +20,6 @@ template <typename... type>
|
||||
{
|
||||
}
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
|
||||
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
@@ -79,17 +39,29 @@ struct CK_PRINTF<ConvertTo,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return make_str_literal("%8.3f");
|
||||
return std::make_tuple(make_str_literal("%8.3f"), T{});
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return make_str_literal("%5d");
|
||||
return std::make_tuple(make_str_literal("%5d"), T{});
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return make_str_literal("%5u");
|
||||
return std::make_tuple(make_str_literal("%5u"), T{});
|
||||
else if constexpr(sizeof(T) == 1)
|
||||
return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{});
|
||||
else if constexpr(sizeof(T) == 2)
|
||||
return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{});
|
||||
else if constexpr(sizeof(T) == 4)
|
||||
return std::make_tuple(make_str_literal("0x%08x"), uint32_t{});
|
||||
else
|
||||
return make_str_literal("0x%08x");
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
template <typename T>
|
||||
using default_format_t =
|
||||
std::remove_reference_t<decltype(std::get<0>(default_format_and_type<T>()))>;
|
||||
template <typename T>
|
||||
using default_type_t =
|
||||
std::remove_reference_t<decltype(std::get<1>(default_format_and_type<T>()))>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
@@ -108,49 +80,58 @@ struct CK_PRINTF<ConvertTo,
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is>
|
||||
template <typename T, index_t N, typename Y, index_t... Is, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>) const
|
||||
std::integer_sequence<index_t, Is...>,
|
||||
Args&&... args) const
|
||||
{
|
||||
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
using FMT1 = std::
|
||||
conditional_t<sizeof...(FMTChars) == 0, default_format_t<Y>, str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
|
||||
printf(fmt_wrap_v.data,
|
||||
get_thread_id(),
|
||||
N,
|
||||
args...,
|
||||
bit_cast<default_type_t<Y>>(type_convert<Y>(buf[Is]))...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
|
||||
template <typename T, index_t N, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf, Args&&... args) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
|
||||
impl<T, N, ConvertTo_>(
|
||||
buf, std::make_integer_sequence<index_t, N>{}, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... TS>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
|
||||
template <typename... TS, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor,
|
||||
Args&&... args) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer());
|
||||
return operator()(tensor.get_thread_buffer(), std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void print_warp0(T&& x)
|
||||
{
|
||||
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
|
||||
if(get_thread_id() < get_warp_size())
|
||||
print(std::forward<T>(x));
|
||||
}
|
||||
template <typename... Ts>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<Ts...>
|
||||
{
|
||||
using base_t = CK_PRINTF<Ts...>;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
|
||||
template <typename T, typename... Args>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf);
|
||||
base_t::operator()(buf, std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,6 +7,51 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace str_literal_detail {
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
} // namespace str_literal_detail
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <size_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
str_literal_detail::makeTuple( \
|
||||
std::make_index_sequence<str_literal_detail::constexpr_strlen(lit_)>()))
|
||||
|
||||
/// Declare a ck_tile::print() interface that gets specialized in each header file for types that
|
||||
/// can be printed.
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user