From 897c2bd42296a8e830d0de954cffa44f90abd972 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 14 Nov 2025 02:43:22 +0000 Subject: [PATCH] Merge commit '4d629cd2b0bb0b4b210881be0db398bcd382f444' into develop --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 2 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 2 +- .../core/tensor/tensor_adaptor_coordinate.hpp | 112 +++++++++++++++++ .../ck_tile/core/tensor/tensor_coordinate.hpp | 5 + include/ck_tile/core/utility/debug.hpp | 117 ++++++++---------- include/ck_tile/core/utility/print.hpp | 45 +++++++ .../ops/flatmm/kernel/flatmm_kernel.hpp | 16 ++- 7 files changed, 223 insertions(+), 76 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 9155b27dba..cf05abd51c 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -47,7 +47,7 @@ static constexpr inline auto is_row_major(Layout layout_) // mfma_type, 0:32x32, 1:16x16 template -auto shuffle_b(const ck_tile::HostTensor& t) +auto shuffle_b_v0(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 69bf39f670..4063fe284e 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -103,7 +103,7 @@ int run_flatmm_example_with_layouts(int argc, } else { - return shuffle_b(b_origin_host); + return shuffle_b_v0(b_origin_host); } }(); ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp index 0d398d4237..d7b9a466ef 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/container/multi_index.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/print.hpp" namespace ck_tile { @@ -254,4 +255,115 @@ CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& ad adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord); } +namespace detail { +template , typename SUFFIX = str_literal<>> +struct CK_PRINT_X_; + +template +struct CK_PRINT_X_, str_literal> +{ + template + struct detail; + template + struct detail< + tensor_adaptor_coordinate> + { + using coord_t = + tensor_adaptor_coordinate; + + template + CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format_i() + { + constexpr bool is_bottom = + sequence_any_of(BottomDimensionHiddenIds{}, [](auto b) { return b == I; }); + constexpr bool is_top = + sequence_any_of(TopDimensionHiddenIds{}, [](auto t) { return t == I; }); + constexpr auto d = make_str_literal("%d"); + if constexpr(is_bottom && is_top) + return make_str_literal("_^") + d; + else if constexpr(is_bottom) + return make_str_literal("_") + d; + else if constexpr(is_top) + return make_str_literal("^") + d; + else + return d; + } + template + CK_TILE_HOST_DEVICE static constexpr auto get_hidden_format() + { + constexpr auto sep = make_str_literal(" "); + if constexpr(N == 0) + return str_literal<>{}; + else + return get_hidden_format() + sep + get_hidden_format_i(); + } + CK_TILE_HOST_DEVICE static constexpr auto get_format() + { + constexpr auto d = make_str_literal("%d"); + constexpr auto sep = make_str_literal(" "); + constexpr auto bottom_fmt = + d.template duplicate_n(sep); + constexpr auto top_fmt = d.template duplicate_n(sep); + constexpr auto hidden_fmt = get_hidden_format(); + return make_str_literal("[ __") + bottom_fmt + make_str_literal("__ | ^^") + top_fmt + + make_str_literal("^^ | ") + hidden_fmt + make_str_literal(" ]"); + } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_values() + { + return BottomDimensionHiddenIds::size() + TopDimensionHiddenIds::size() + NDimHidden; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_values(const coord_t& coord) + { + return container_concat( + coord.get_bottom_index(), coord.get_top_index(), coord.get_hidden_index()); + } + }; + + CK_TILE_HOST_DEVICE static constexpr auto get_prefix() + { + constexpr auto fmt_tid = make_str_literal("tid %03d: "); + if constexpr(sizeof...(PREFIXChars) == 0) + return fmt_tid; + else + return fmt_tid + make_str_literal(" ") + str_literal{}; + } + CK_TILE_HOST_DEVICE static constexpr auto get_suffix() + { + constexpr auto lf = make_str_literal("\n"); + if constexpr(sizeof...(SUFFIXChars) == 0) + return lf; + else + return str_literal{} + lf; + } + + template + CK_TILE_HOST_DEVICE void impl(str_literal, + const TArgs& targs, + std::integer_sequence, + Args&&... args) const + { + constexpr auto fmt_wrap_v = get_prefix() + str_literal{} + get_suffix(); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" + printf(fmt_wrap_v.data, get_thread_id(), args..., targs.at(number())...); +#pragma clang diagnostic pop + } + template + CK_TILE_HOST_DEVICE void operator()(T&& x, Args&&... args) const + { + using detail_t = detail>; + impl(detail_t::get_format(), + detail_t::get_values(std::forward(x)), + std::make_integer_sequence{}, + std::forward(args)...); + } +}; +} // namespace detail + +template +CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate& coord) +{ + detail::CK_PRINT_X_<>{}(coord); +} } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_coordinate.hpp index 9b8fe731fd..a51da9f844 100644 --- a/include/ck_tile/core/tensor/tensor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_coordinate.hpp @@ -89,4 +89,9 @@ CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& return adaptor_coordinate_is_valid(tensor_desc, coord); } +template +CK_TILE_HOST_DEVICE void print(const tensor_coordinate& coord) +{ + print(static_cast::Base>(coord)); +} } // namespace ck_tile diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp index 9f0f931bc8..581b095383 100644 --- a/include/ck_tile/core/utility/debug.hpp +++ b/include/ck_tile/core/utility/debug.hpp @@ -7,6 +7,8 @@ #include #include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/utility/print.hpp" +#include "ck_tile/core/arch/arch.hpp" namespace ck_tile { template @@ -18,48 +20,6 @@ template { } -template -struct str_literal -{ - static constexpr const char data[] = {Xs..., '\0'}; - static constexpr const size_t size = sizeof...(Xs); - - template - CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal /*rhs*/) const - { - return str_literal{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal sep) - { - if constexpr(N == 0) - return str_literal<>{}; - else if constexpr(N == 1) - return str_literal{}; - else - return duplicate_n(sep) + str_literal{}; - } -}; - -#define make_str_literal(lit_) \ - std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \ - makeTuple(std::make_index_sequence())) - -template -constexpr std::tuple...> -makeTuple(std::index_sequence) noexcept -{ - return {}; -} -constexpr size_t constexpr_strlen(const char* c) -{ - size_t t = 0; - while(*c++) - ++t; - return t; -} - template struct static_distributed_tensor; @@ -79,17 +39,29 @@ struct CK_PRINTF> { template - CK_TILE_HOST_DEVICE static constexpr auto default_format() + CK_TILE_HOST_DEVICE static constexpr auto default_format_and_type() { if constexpr(std::is_same_v) - return make_str_literal("%8.3f"); + return std::make_tuple(make_str_literal("%8.3f"), T{}); else if constexpr(std::is_same_v) - return make_str_literal("%5d"); + return std::make_tuple(make_str_literal("%5d"), T{}); else if constexpr(std::is_same_v) - return make_str_literal("%5u"); + return std::make_tuple(make_str_literal("%5u"), T{}); + else if constexpr(sizeof(T) == 1) + return std::make_tuple(make_str_literal("0x%02hhx"), uint8_t{}); + else if constexpr(sizeof(T) == 2) + return std::make_tuple(make_str_literal("0x%04hx"), uint16_t{}); + else if constexpr(sizeof(T) == 4) + return std::make_tuple(make_str_literal("0x%08x"), uint32_t{}); else - return make_str_literal("0x%08x"); + static_assert(false, "Unsupported type"); } + template + using default_format_t = + std::remove_reference_t(default_format_and_type()))>; + template + using default_type_t = + std::remove_reference_t(default_format_and_type()))>; CK_TILE_HOST_DEVICE static constexpr auto get_prefix() { @@ -108,49 +80,58 @@ struct CK_PRINTF{} + lf; } - template + template CK_TILE_HOST_DEVICE void impl(const thread_buffer& buf, - std::integer_sequence) const + std::integer_sequence, + Args&&... args) const { - using FMT1 = std::conditional_t()), - str_literal>; + using FMT1 = std:: + conditional_t, str_literal>; constexpr auto fmt_v = FMT1::template duplicate_n(make_str_literal(" ")); constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix(); #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wformat-nonliteral" - printf(fmt_wrap_v.data, get_thread_id(), N, type_convert(buf[Is])...); + printf(fmt_wrap_v.data, + get_thread_id(), + N, + args..., + bit_cast>(type_convert(buf[Is]))...); #pragma clang diagnostic pop } - template - CK_TILE_HOST_DEVICE void operator()(const thread_buffer& buf) const + template + CK_TILE_HOST_DEVICE void operator()(const thread_buffer& buf, Args&&... args) const { using ConvertTo_ = std::conditional_t, T, ConvertTo>; - impl(buf, std::make_integer_sequence{}); + impl( + buf, std::make_integer_sequence{}, std::forward(args)...); } - template - CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor& tensor) const + template + CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor& tensor, + Args&&... args) const { - return operator()(tensor.get_thread_buffer()); + return operator()(tensor.get_thread_buffer(), std::forward(args)...); } }; -template , - typename PREFIX = str_literal<>, - typename SUFFIX = str_literal<>> -struct CK_PRINTF_WARP0 : public CK_PRINTF +template +CK_TILE_HOST_DEVICE void print_warp0(T&& x) { - using base_t = CK_PRINTF; + if(get_thread_id() < get_warp_size()) + print(std::forward(x)); +} +template +struct CK_PRINTF_WARP0 : public CK_PRINTF +{ + using base_t = CK_PRINTF; - template - CK_TILE_HOST_DEVICE void operator()(const T& buf) const + template + CK_TILE_HOST_DEVICE void operator()(const T& buf, Args&&... args) const { if(get_thread_id() < get_warp_size()) - base_t::operator()(buf); + base_t::operator()(buf, std::forward(args)...); } }; diff --git a/include/ck_tile/core/utility/print.hpp b/include/ck_tile/core/utility/print.hpp index 04635959af..b7279a1ef2 100644 --- a/include/ck_tile/core/utility/print.hpp +++ b/include/ck_tile/core/utility/print.hpp @@ -7,6 +7,51 @@ namespace ck_tile { +namespace str_literal_detail { +template +constexpr std::tuple...> +makeTuple(std::index_sequence) noexcept +{ + return {}; +} +constexpr size_t constexpr_strlen(const char* c) +{ + size_t t = 0; + while(*c++) + ++t; + return t; +} +} // namespace str_literal_detail + +template +struct str_literal +{ + static constexpr const char data[] = {Xs..., '\0'}; + static constexpr const size_t size = sizeof...(Xs); + + template + CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal /*rhs*/) const + { + return str_literal{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal sep) + { + if constexpr(N == 0) + return str_literal<>{}; + else if constexpr(N == 1) + return str_literal{}; + else + return duplicate_n(sep) + str_literal{}; + } +}; + +#define make_str_literal(lit_) \ + std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \ + str_literal_detail::makeTuple( \ + std::make_index_sequence())) + /// Declare a ck_tile::print() interface that gets specialized in each header file for types that /// can be printed. template diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a53a4a499e..7523acc080 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -662,17 +662,21 @@ struct FlatmmKernel const auto scale_m_view = make_naive_tensor_view( kargs.scale_m_ptr.ptr, - make_tuple( - kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA), + make_tuple(kargs.M / ScaleGranularityM, + ScaleGranularityKA == 0 + ? 1 + : splitk_batch_offset.splitted_k / + (ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)), make_tuple(scale_stride_m, 0), number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, number<1>{}); const auto scale_n_view = make_naive_tensor_view( kargs.scale_n_ptr.ptr, - make_tuple( - ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), - kargs.N / ScaleGranularityN), + make_tuple(ScaleGranularityKB == 0 + ? 1 + : (splitk_batch_offset.splitted_k / + (ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)), + kargs.N / ScaleGranularityN), make_tuple(0, scale_stride_n), number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, number<1>{});