From c69539fe3c5fd75b7fa67ec4f7746cea0cf4c6bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Thu, 28 Aug 2025 15:01:55 +0000 Subject: [PATCH] Optimize LDS write order for packed cast. --- .../threadwise_tensor_slice_transfer.hpp | 139 ++++++++++-------- 1 file changed, 76 insertions(+), 63 deletions(-) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bd440c3f45..60e894ba09 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -13,6 +13,42 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" namespace ck { + +// Compares two multi-indices element by element via At(i); indices whose Size() differ are reported as unequal. +template +constexpr bool indices_equal(const Idx1& idx1, const Idx2& idx2) +{ + if constexpr (Idx1::Size() != Idx2::Size()) { + return false; + } else { + bool equal = true; + static_for<0, Idx1::Size(), 1>{}([&](auto i) { + equal = equal && (idx1.At(i) == idx2.At(i)); + }); + return equal; + } +} +/* LinearIndexFinder::find() is a compile-time linear search: it compares a multi-index obtained from SpaceFillingCurveDst with one obtained from SpaceFillingCurveSrc using indices_equal, returns SearchIdx on a match, otherwise recurses into another LinearIndexFinder instantiation, and returns index_t(-1) once SearchIdx >= MaxIdx. NOTE(review): the template parameter list and the Number{} arguments are not legible in this copy of the patch; presumably the dst index is fixed and SearchIdx advances by one per recursion - confirm against the applied file. */ +template +struct LinearIndexFinder +{ + static constexpr index_t find() + { + if constexpr (SearchIdx >= MaxIdx) { + return index_t(-1); // Not found + } else { + constexpr auto dst_md = SpaceFillingCurveDst::GetIndex(Number{}); + constexpr auto src_md = SpaceFillingCurveSrc::GetIndex(Number{}); + + if constexpr (indices_equal(dst_md, src_md)) { + return SearchIdx; + } else { + return LinearIndexFinder::find(); + } + } + } +}; + // Assume: // 1. src: // 1. 
SrcDesc is known at compile-time @@ -44,7 +80,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through static constexpr index_t nDim = SliceLengths::Size(); - static constexpr bool SerpentineAccessPattern = false; + static constexpr bool SerpentineAccessPattern = true; static constexpr bool float_input_and_bf16_output_ = std::is_same_v && std::is_same_v; @@ -96,86 +132,63 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve, SerpentineAccessPattern>; - // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? - static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + using SpaceFillingCurveSrc = SpaceFillingCurve, + false>; - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + static_assert(DstScalarPerVector == SpaceFillingCurveDst::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess(); static_assert(std::is_same_v>, "wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>"); - static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); + static_assert(1 == SpaceFillingCurveDst::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector"); - constexpr index_t num_pairs = num_access / 2; - constexpr bool has_odd_element = (num_access % 2 == 1); + static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(), + "wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access."); - // TODO: Enable also odd number of elements. 
- static_assert(!has_odd_element, "wrong!Slice should have even number of elements."); - - static_for<0, num_pairs, 1>{}([&](auto i_pair) - { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(i_pair); - constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md); - - union - { - float src_float; - bhalf2_t src_bf16x2; - } packed_value; - - packed_value.src_float = src_buf[Number{}]; + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr index_t idx_src_1d = LinearIndexFinder::find(); + static_assert(idx_src_1d != index_t(-1), "wrong! Cannot find linear index."); + // Map linear index to the packed BF16 index + constexpr index_t idx_src_1d_packed = idx_src_1d / 2; + constexpr index_t pair_index = idx_src_1d % 2; + + constexpr auto idx_md_src = SpaceFillingCurveSrc::GetIndex(Number{}); + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src); + + union + { + float fp32; + ck::bhalf2_t bf16x2; + } converter; + converter.fp32 = src_buf[Number{}]; + const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - constexpr auto idx_1d_0 = I2 * i_pair; - constexpr auto idx_1d_1 = I2 * i_pair + I1; - - const auto dst_offset_0 = dst_coord_.GetOffset(); + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + converter.bf16x2[pair_index]); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurveDst::GetForwardStep(idx_1d); - // Move to the next dst coordinate - constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0); move_tensor_coordinate( - dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_0)); - - const auto dst_offset_1 = dst_coord_.GetOffset(); - - if (dst_offset_1 - dst_offset_0 == 1) - { - // Store the packed value directly since we have consequtive locations is the dst buffer. 
- dst_buf.template Update( - dst_coord_.GetOffset(), - is_dst_valid, - packed_value.src_bf16x2); - } - else - { - // Store the packed value into the dst buffer one by one. - dst_buf.template Update( - dst_offset_0, - is_dst_valid, - packed_value.src_bf16x2[0]); - - dst_buf.template Update( - dst_offset_1, - is_dst_valid, - packed_value.src_bf16x2[1]); - } - - // Move to next dst coordinate, unless this was the last pair - if constexpr(i_pair.value != num_pairs - 1) - { - - constexpr auto forward_step_1 = SpaceFillingCurve::GetForwardStep(idx_1d_1); - move_tensor_coordinate( - dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_1)); + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } });