From c9ba50214150fefa443e16517bf86eae1bbfc14c Mon Sep 17 00:00:00 2001 From: Aleksander Dudek Date: Thu, 2 Oct 2025 01:02:05 -0500 Subject: [PATCH] [CK_TILE] Vector stores for C Column Layout --- .../core/arch/generic_memory_space_atomic.hpp | 3 +- include/ck_tile/core/tensor/buffer_view.hpp | 7 +- .../ops/epilogue/cshuffle_epilogue.hpp | 72 ++++++++++--------- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 2 +- 4 files changed, 49 insertions(+), 35 deletions(-) diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e56bcadcba..9bb8ea37e4 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -350,7 +350,7 @@ CK_TILE_DEVICE void atomic_add(fp16x2_t* p_dst, fp16x2_t const& x) new_.f162 = add_f16x2_t(cur_v.f162, x); new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); - } while(cur_v.u32 != old_v); +} while(cur_v.u32 != old_v); #endif } @@ -361,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) (std::is_same::value && (N == 1)) || (std::is_same::value && (N == 1 || N == 2 || N == 4)) || (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 4 || N == 8 || N == 16)) || diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 3b747dae84..fd78fb2362 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -631,7 +631,12 @@ struct buffer_view, int32_t> || std::is_same_v, float> || - (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0) +#if defined(__gfx94__) + || (std::is_same_v, bf16_t> && scalar_per_x_vector % 2 == 0); +#else + ; +#endif #elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) bool constexpr use_amd_buffer_addressing = std::is_same_v, int32_t>; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e0a39a5aea..a416dcec74 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -114,6 +114,14 @@ struct CShuffleEpilogue static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + static constexpr bool IsERowMajor = + std::is_same_v ? true : false; + static constexpr bool IsEColMajor = + std::is_same_v ? true : false; + + static_assert(std::disjunction_v, bool_constant>, + "Unsupported ELayout!"); + static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); /** @@ -133,12 +141,12 @@ struct CShuffleEpilogue return VectorSizeC; } constexpr index_t max_vector_size = 16; - if constexpr(std::is_same_v) + if constexpr(IsERowMajor) { return std::min(static_cast(NPerIteration), static_cast(max_vector_size / sizeof(ODataType))); } - else if constexpr(std::is_same_v) + else if constexpr(IsEColMajor) { return std::min(static_cast(MPerIteration), static_cast(max_vector_size / sizeof(ODataType))); @@ -160,12 +168,15 @@ struct CShuffleEpilogue constexpr index_t max_vector_size = 16; using DiDataType = remove_cvref_t>; using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) + static_assert(std::is_same_v, "ELayout is not equal to DiLayout!"); + constexpr bool IsDRowMajor = std::is_same_v ? true : false; + constexpr bool IsDColMajor = std::is_same_v ? true : false; + if constexpr(IsDRowMajor) { return std::min(static_cast(NPerIteration), static_cast(max_vector_size / sizeof(DiDataType))); } - else if constexpr(std::is_same_v) + else if constexpr(IsDColMajor) { return std::min(static_cast(MPerIteration), static_cast(max_vector_size / sizeof(DiDataType))); @@ -174,7 +185,6 @@ struct CShuffleEpilogue { static_assert(false, "Unsupported DLayout!"); } - return max_vector_size / sizeof(DiDataType); } /** * @brief Shuffle tile configuration parameters @@ -225,6 +235,22 @@ struct CShuffleEpilogue static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); + static constexpr index_t NumYXdlPerWavePerShuffle = + IsERowMajor ? NumMXdlPerWavePerShuffle : NumNXdlPerWavePerShuffle; + static constexpr index_t NumXXdlPerWavePerShuffle = + IsERowMajor ? NumNXdlPerWavePerShuffle : NumMXdlPerWavePerShuffle; + + static constexpr index_t YPerIterationShuffle = + IsERowMajor ? MPerIterationShuffle : NPerIterationShuffle; + static constexpr index_t XPerIterationShuffle = + IsERowMajor ? NPerIterationShuffle : MPerIterationShuffle; + + static constexpr index_t YPerBlock = IsERowMajor ? kMPerBlock : kNPerBlock; + static constexpr index_t XPerBlock = IsERowMajor ? kNPerBlock : kMPerBlock; + + static constexpr index_t YWave = IsERowMajor ? MWave : NWave; + static constexpr index_t XWave = IsERowMajor ? NWave : MWave; + using WG = WarpGemmDispatcher CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { - // N is contiguous dimension - if constexpr(std::is_same_v) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{})); - } - // M is contiguous dimension - else if constexpr(std::is_same_v) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number<1>{}, number{})); - } - else - { - static_assert(false, "Unsupported ELayout!"); - } + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{})); } CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() { constexpr auto block_outer_dstr_encoding = tile_distribution_encoding, - tuple, - sequence>, + tuple, + sequence>, tuple>, tuple>, sequence<1, 2>, @@ -585,24 +596,21 @@ struct CShuffleEpilogue auto in_lds_window = make_tile_window( o_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, LdsTileDistr); auto out_lds_window = make_tile_window( o_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}); constexpr index_t num_access = SFC::get_num_of_access(); - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 51ad4e3dd1..a9e4da3eb5 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -743,7 +743,7 @@ struct UniversalGemmKernel make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_E), number<1>{}, - number<1>{}); + number{}); } }();