From 200d2b22d413d78aee2f26a404699e0f54435446 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 25 Mar 2024 19:45:07 +0000 Subject: [PATCH] fix scratch in fp8 kernel --- .../ck_tile/core/container/thread_buffer.hpp | 83 +++++++++++++++++-- include/ck_tile/core/numeric/vector_type.hpp | 14 ++++ include/ck_tile/core/tensor/shuffle_tile.hpp | 4 +- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 10 +-- 4 files changed, 95 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index 3c3894c148..a7dad5233b 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -59,6 +59,64 @@ struct thread_buffer { template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr auto _get_as() const + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + static_assert(N % kSPerX == 0); + + union { + thread_buffer data {}; + // tuple_array sub_data; + value_type sub_data[N]; + } vx; + static_for<0, N, 1>{}( + [&](auto j) { vx.sub_data[j] = data[j]; }); + return vx.data; + } + + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE const constexpr remove_reference_t _get_as(number is) const + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + + union { + X_ data {}; + tuple_array sub_data; + } vx; + static_for<0, kSPerX, 1>{}( + [&](auto j) { vx.sub_data(j) = operator[]((is * number{}) + j); }); + return vx.data; + } + +#if 0 + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void _set_as(number is, X_ x) + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + + union { + X_ data; + tuple_array sub_data; + } vx {x}; + + static_for<0, kSPerX, 1>{}( + [&](auto j) { operator()((is * number{}) + j) = vx.sub_data[j]; }); + } +#endif + + #define TB_COMMON_AS() \ static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ constexpr int vx = sizeof(value_type) * N / sizeof(Tx) @@ -67,19 +125,26 @@ struct thread_buffer { CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS(); return reinterpret_cast&>(data);} template - CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS(); + CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS(); + if constexpr(sizeof(value_type) <= 1 ) + return _get_as(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future + else return reinterpret_cast&>(data);} - template - CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS(); - return reinterpret_cast&>(data).get(i);} - template - CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const {TB_COMMON_AS(); - return reinterpret_cast&>(data).get(i);} + template + CK_TILE_HOST_DEVICE auto & get_as(number) {TB_COMMON_AS(); + return reinterpret_cast&>(data).get(number{});} + template + CK_TILE_HOST_DEVICE constexpr auto get_as(number) const {TB_COMMON_AS(); + if constexpr(sizeof(value_type) <= 1 ) + return _get_as(number{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future + else + return reinterpret_cast&>(data).get(number{});} template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) - { TB_COMMON_AS(); reinterpret_cast&>(data).at(i) = x; } + { TB_COMMON_AS(); reinterpret_cast&>(data).at(i) = x; } template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) - { TB_COMMON_AS(); reinterpret_cast&>(data).at(number{}) = x; } + { TB_COMMON_AS(); reinterpret_cast&>(data).at(number{}) = x; } + #undef TB_COMMON_AS }; // clang-format on diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 9d09e06230..85d9be1c94 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -38,6 +38,16 @@ struct ext_vector static_assert(!std::is_class_v); using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; + +template +struct ext_vector +{ + static constexpr index_t N = Vs_ * N_; + using value_type = typename native_t>::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + } // namespace impl template @@ -60,6 +70,10 @@ struct vector_traits static constexpr index_t vector_size = N; }; +template +using has_same_scalar_type = std::is_same>::scalar_type, + typename vector_traits>::scalar_type>; + // below are some pre-defines of ext_vector_type // attention! 2 vector type could be just the same type // fp64 diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 502b7560a4..baf009add2 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -119,8 +119,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT static_assert(in_offset % vec_length_in == 0); in_vectors(i).template get_as()(I0) = - in_tensor.get_thread_buffer().template get_as( - number{}); + in_tensor.get_thread_buffer() + .template get_as()[number{}]; }); // transpose diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index 843d091c48..eb9dbf127d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -39,9 +39,9 @@ struct WarpGemmImpl constexpr auto I0 = number<0>{}; - const auto a_vec = a.get_thread_buffer().template get_as(I0); - const auto b_vec = b.get_thread_buffer().template get_as(I0); - auto c_vec = c.get_thread_buffer().template get_as(I0); + const auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; // c_vec += a_vec * b_vec WarpGemmAttribute{}(c_vec, a_vec, b_vec); @@ -59,8 +59,8 @@ struct WarpGemmImpl constexpr auto I0 = number<0>{}; - const auto a_vec = a.get_thread_buffer().template get_as(I0); - const auto b_vec = b.get_thread_buffer().template get_as(I0); + const auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; // c_vec = a_vec * b_vec auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);