diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 60e894ba09..094482092a 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -82,9 +82,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through static constexpr bool SerpentineAccessPattern = true; - static constexpr bool float_input_and_bf16_output_ = - std::is_same_v && std::is_same_v; - using Index = MultiIndex; using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); @@ -101,6 +98,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through "wrong! SrcDesc need to known at compile-time"); static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, "wrong! Not divisible"); + + // For now, SrcData must be float and DstData must be ck::bhalf_t + static_assert(std::is_same_v, + "wrong! SrcData must be float"); + static_assert(std::is_same_v, + "wrong! DstData must be bhalf_t"); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) @@ -273,7 +276,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr bool SerpentineAccessPattern = false; + + // Currently these must match, either true or false for both. + // Otherwise we increase the register usage and end up using the scratch memory. + // The linear pattern seems to be better than serpentine. + // However, for LDS writes, serpentine pattern could be better. 
+ static constexpr bool SerpentineAccessPatternDst = true; + static constexpr bool SerpentineAccessPatternSrc = true; static constexpr index_t nDim = SliceLengths::Size(); @@ -332,14 +341,24 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve, - SerpentineAccessPattern>; - - static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); + SerpentineAccessPatternDst>; - constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); + using SpaceFillingCurveSrc = SpaceFillingCurve, + SerpentineAccessPatternSrc>; + + static_assert(1 == SpaceFillingCurveDst::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector"); + + static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(), + "wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access."); + + constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess(); constexpr index_t num_pairs = num_access / 2; constexpr bool has_odd_element = (num_access % 2 == 1); @@ -349,13 +368,23 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast ck::float2_t float2_buffer; static_for<0, num_pairs, 1>{}([&](auto i_pair) { - constexpr auto idx_1d_0 = I2 * i_pair; - constexpr auto idx_1d_1 = I2 * i_pair + I1; - constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0); - constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1); - - constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0); - constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1); + // First of the pair of the float values + constexpr auto idx_1d_dst_0 = I2 * i_pair; + constexpr index_t idx_1d_src_0 = SerpentineAccessPatternDst != SerpentineAccessPatternSrc + ? 
LinearIndexFinder::find() + : idx_1d_dst_0.value; + static_assert(idx_1d_src_0 != index_t(-1), "wrong! Cannot find first linear index."); + constexpr auto idx_md_src_0 = SpaceFillingCurveSrc::GetIndex(Number{}); + constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src_0); + + // Second of the pair of the float values + constexpr auto idx_1d_dst_1 = I2 * i_pair + I1; + constexpr index_t idx_1d_src_1 = SerpentineAccessPatternDst != SerpentineAccessPatternSrc + ? LinearIndexFinder::find() + : idx_1d_dst_1.value; + static_assert(idx_1d_src_1 != index_t(-1), "wrong! Cannot find first linear index."); + constexpr auto idx_md_src_1 = SpaceFillingCurveSrc::GetIndex(Number{}); + constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src_1); if constexpr (src_offset_1 - src_offset_0 == 1) { @@ -373,42 +402,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + // Store the packed value into the dst buffer one by one. const auto dst_offset_0 = dst_coord_.GetOffset(); + dst_buf.template Update( + dst_offset_0, + is_dst_valid, + packed_value[0]); // Move to the next dst coordinate - constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0); + constexpr auto forward_step_0 = SpaceFillingCurveDst::GetForwardStep(idx_1d_dst_0); move_tensor_coordinate( dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_0)); const auto dst_offset_1 = dst_coord_.GetOffset(); - - if (dst_offset_1 - dst_offset_0 == 1) - { - // Store the packed value directly since we have consequtive locations is the dst buffer. - dst_buf.template Update( - dst_coord_.GetOffset(), - is_dst_valid, - packed_value); - } - else - { - // Store the packed value into the dst buffer one by one. 
- dst_buf.template Update( - dst_offset_0, - is_dst_valid, - packed_value[0]); - - dst_buf.template Update( - dst_offset_1, - is_dst_valid, - packed_value[1]); - } + dst_buf.template Update( + dst_offset_1, + is_dst_valid, + packed_value[1]); // Move to next dst coordinate, unless this was the last pair if constexpr(i_pair.value != num_pairs - 1) { - constexpr auto forward_step_1 = SpaceFillingCurve::GetForwardStep(idx_1d_1); + constexpr auto forward_step_1 = SpaceFillingCurveDst::GetForwardStep(idx_1d_dst_1); move_tensor_coordinate( dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_1)); } @@ -433,7 +449,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast using SpaceFillingCurve = SpaceFillingCurve, - SerpentineAccessPattern>; + SerpentineAccessPatternDst>; constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); if constexpr(num_access == 0)