diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index aad793fe89..23d685747c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1646,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -1763,20 +1763,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } + // if constexpr (is_gfx650_and_bf16_output()) + // { + // auto c_thread_packed_cast = PackedCastV2< + // M2, + // M4, + // CShuffleMXdlPerWavePerShuffle, + // CShuffleNXdlPerWavePerShuffle + // >{}; + // c_thread_packed_cast.Run( + // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) + // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + // c_thread_buf // source buffer + // ); + // } // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -2093,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -2204,20 +2204,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } + // if constexpr (is_gfx650_and_bf16_output()) + // { + // auto c_thread_packed_cast = PackedCastV2< + // M2, + // M4, + // CShuffleMXdlPerWavePerShuffle, + // CShuffleNXdlPerWavePerShuffle + // >{}; + // c_thread_packed_cast.Run( + // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + // c_thread_buf // source buffer + // ); + // } // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 15668f06e6..d89f5f191c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -235,12 +235,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr index_t SrcScalarPerVector = 1; static constexpr index_t BufScalarPerVector = 2; - // We cannot use SnakedCurve for packed cast, otherwise the vectorized store will not work correctly. - static constexpr bool SnakedCurve = false; - static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; @@ -295,17 +291,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); - // SFC to access the source buffer and destination buffers. using SpaceFillingCurve = SpaceFillingCurve, - SnakedCurve>; - - typename vector_type_maker::type buf_vector; - using buf_vector_t = typename vector_type_maker::type::type; + remove_cv_t>; static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); @@ -329,25 +320,35 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast const float val_0 = src_buf[Number{}]; const float val_1 = src_buf[Number{}]; - buf_vector.template AsType()(I0) = bf16x2_convert_rne(val_0, val_1); + const ck::bhalf2_t packed_value= bf16x2_convert_rne(val_0, val_1); const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - // copy data from buf_vector into dst_buf - dst_buf.template Update( + // Store the first of the packed values + dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - buf_vector.template AsType()[Number<0>{}]); + packed_value[0]); + // Move to next dst coordinate + constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0); + move_tensor_coordinate( + dst_desc, dst_coord_, + make_tensor_coordinate_step(dst_desc, forward_step_0)); + + // Store the second of the packed values + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + packed_value[1]); + + // Move to next dst coordinate, unless this was the last pair if constexpr(i_pair.value != num_pairs - 1) { - // Move two steps forward in the space-filling curve. - // This works only if we don't use the snaked access pattern. - constexpr auto forward_step_md = SpaceFillingCurve::GetStepBetween(idx_1d_0, Number{}); - + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1); move_tensor_coordinate( - dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_md)); + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } }); @@ -369,8 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast using SpaceFillingCurve = SpaceFillingCurve, - SnakedCurve>; + remove_cv_t>; constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); if constexpr(num_access == 0) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 95ba91ae20..e22db0dade 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -106,6 +106,9 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f bhalf2_t bf16x2; } converter; + // typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t; + // typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t; + // converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t); converter.bf16x2 = {bf16_convert_rtn(x), bf16_convert_rtn(y)}; x = converter.fp32; } @@ -119,10 +122,11 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f * @return Converted vector of 2 bhalf_t. */ template<> -inline __host__ __device__ constexpr bhalf2_t bf16x2_convert_rne(float x, float y) +inline __host__ __device__ bhalf2_t bf16x2_convert_rne(float x, float y) { - // for gfx950, the compiler will use device instruction v_cvt_pk_bf16_f32 to execute packed cast. - return {bf16_convert_rtn(x), bf16_convert_rtn(y)}; + typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t; + typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t; + return __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t); } // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed