diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index e0bd84efb1..8a89fbd2fb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -898,7 +898,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through - >; + true>>; // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{ @@ -1007,20 +1007,20 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } + // if constexpr (is_gfx650_and_bf16_output()) + // { + // auto c_thread_packed_cast = PackedCastV2< + // M2, + // M4, + // CShuffleMXdlPerWavePerShuffle, + // CShuffleNXdlPerWavePerShuffle + // >{}; + // c_thread_packed_cast.Run( + // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) + // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + // c_thread_buf // source buffer + // ); + // } // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -1308,7 +1308,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through - >; + true>>; // shuffle: threadwise copy C from VGPR to LDS 
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{ @@ -1417,20 +1417,20 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } + // if constexpr (is_gfx650_and_bf16_output()) + // { + // auto c_thread_packed_cast = PackedCastV2< + // M2, + // M4, + // CShuffleMXdlPerWavePerShuffle, + // CShuffleNXdlPerWavePerShuffle + // >{}; + // c_thread_packed_cast.Run( + // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + // c_thread_buf // source buffer + // ); + // } // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 23d685747c..4e3f47ea88 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1646,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast< + ThreadwiseTensorSliceTransfer_v1r3_vectorized< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -2093,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast< + ThreadwiseTensorSliceTransfer_v1r3_vectorized< AccDataType, 
CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 52814fae66..3a14c1699b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -897,7 +897,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< FloatAcc, FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), @@ -1002,20 +1002,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // make sure it's safe to do ds_write block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle - >{}; - c_thread_packed_cast.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct) - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin - c_thread_buf // source buffer - ); - } + // if constexpr (is_gfx650_and_bf16_output()) + // { + // auto c_thread_packed_cast = PackedCastV2< + // M2, + // M4, + // CShuffleMRepeatPerShuffle, + // CShuffleNRepeatPerShuffle + // >{}; + // c_thread_packed_cast.Run( + // c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct) + // make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin + // c_thread_buf // source buffer + // ); + // } // VGPR to LDS c_thread_copy_vgpr_to_lds.Run( diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 152ddb4e45..32ff815067 100644 --- 
a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -438,6 +438,185 @@ private: true> thread_scratch_; }; +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOriginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOriginIdx is not known at compile-time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3_vectorized +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr bool SnakedAccess = false; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_vectorized(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx, + const ElementwiseOperation&) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + + // Assert that elementwise op is pass through. + static_assert( + std::is_same_v, ck::tensor_operation::element_wise::PassThrough>, + "wrong! ElementwiseOperation must be PassThrough"); + + // For now, SrcData must be float and DstData must be ck::bhalf_t + static_assert(std::is_same_v, + "wrong! SrcData must be float"); + static_assert(std::is_same_v, + "wrong! 
DstData must be bhalf_t"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve, + SnakedAccess>; + + static_assert(2 == SpaceFillingCurve::ScalarPerVector, "wrong!2 != SpaceFillingCurve::ScalarPerVector"); + ck::bhalf2_t dst_vector; + using dst_vector_t = ck::bhalf2_t; + + constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); + static_for<0, num_access, 1>{}([&](auto idx_1d) + { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + constexpr auto idx_src_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md); + constexpr auto idx_src_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + src_scalar_step_in_vector); + + const float val_0 = src_buf[Number{}]; + const float val_1 = src_buf[Number{}]; + dst_vector = bf16x2_convert_rne(val_0, val_1); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector 
into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve, + SnakedAccess>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } +private: + DstCoord dst_coord_; +}; + // Assume: // 1. src: // 1. 
SrcDesc is known at compile-time diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index e22db0dade..479b1d7751 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -106,10 +106,10 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f bhalf2_t bf16x2; } converter; - // typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t; - // typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t; - // converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t); - converter.bf16x2 = {bf16_convert_rtn(x), bf16_convert_rtn(y)}; + typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t; + typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t; + converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t); + //converter.bf16x2 = {bf16_convert_rtn(x), bf16_convert_rtn(y)}; x = converter.fp32; }