diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index ee059b204f..e0bd84efb1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -9,6 +9,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -897,7 +898,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -1292,7 +1308,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index b96f4b6e4d..aad793fe89 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -9,6 +9,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -1645,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast< + ThreadwiseTensorSliceTransfer_v1r3_pass_through< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -1762,6 +1763,21 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); + if constexpr (is_gfx650_and_bf16_output()) + { + auto c_thread_packed_cast = PackedCastV2< + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -2077,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast< + ThreadwiseTensorSliceTransfer_v1r3_pass_through< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -2188,6 +2204,21 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); + if constexpr (is_gfx650_and_bf16_output()) + { + auto c_thread_packed_cast = PackedCastV2< + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index a1f42ffcb6..52814fae66 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" @@ -896,7 +897,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_packed_cast< + ThreadwiseTensorSliceTransfer_v1r3_pass_through< FloatAcc, FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), @@ -1001,6 +1002,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // make sure it's safe to do ds_write block_sync_lds(); + if constexpr (is_gfx650_and_bf16_output()) + { + auto c_thread_packed_cast = PackedCastV2< + M2, + M4, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle + >{}; + c_thread_packed_cast.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct) + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin + c_thread_buf // source buffer + ); + } + // VGPR to LDS c_thread_copy_vgpr_to_lds.Run( c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, diff --git a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp new file mode 100644 index 0000000000..7e0d31ea62 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp @@ -0,0 +1,141 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + + template + struct PackedCast + { + template + __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + // Calculate total elements in this slice + constexpr index_t elements_per_slice = + CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4; + + constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr { + + // We know that the access order is + // Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + // Sequence + + constexpr index_t m4_offset = idx.value % M4; + constexpr index_t m2_offset = (idx.value / M4) % M2; + constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle; + constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle); + + return make_tuple( + src_slice_origin_idx[Number<0>{}] + Number{}, + src_slice_origin_idx[Number<1>{}] + Number{}, + Number<0>{}, // this dim has unit size + Number<0>{}, // this dim has unit size + src_slice_origin_idx[Number<4>{}] + Number{}, + Number<0>{}, // this dim has unit size + src_slice_origin_idx[Number<6>{}] + Number{}, + Number<0>{} // this dim has unit size + ); + }; + + constexpr index_t num_pairs = elements_per_slice / 2; + constexpr bool has_odd_element = (elements_per_slice % 2 == 1); + + static_assert(!has_odd_element, "PackedCast does not support odd number of elements"); + + static_for<0, num_pairs, 1>{}([&](auto pair_idx) { + constexpr auto idx_0 = Number{}; + constexpr auto idx_1 = Number{}; + + constexpr auto coord_0 = calculate_coords(idx_0); + constexpr auto coord_1 = calculate_coords(idx_1); + + constexpr auto offset_0 = src_desc.CalculateOffset(coord_0); + constexpr auto offset_1 = src_desc.CalculateOffset(coord_1); + + float& val_0 = src_buf(Number{}); + const float val_1 = src_buf[Number{}]; + + static_cast_float_to_bhalf_packed_v2(val_0, val_1); + }); + + }; + }; + + + template + struct PackedCastV2 + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + template + __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + using SliceLengths = Sequence; + using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>; + using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>; + using SpaceFillingCurve = SpaceFillingCurve; + + static_assert(SpaceFillingCurve::ScalarPerVector == 1, + "wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2"); + + constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); + constexpr index_t num_pairs = num_access / 2; + constexpr bool has_odd_element = (num_access % 2 == 1); + + static_assert(!has_odd_element, "PackedCastV2 does not support odd number of elements"); + + static_for<0, num_pairs, 1>{}([&](auto i_pair) + { + constexpr auto idx_1d_0 = I2 * i_pair; + constexpr auto idx_1d_1 = I2 * i_pair + I1; + constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0); + constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1); + + constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0); + constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1); + + float& val_0 = src_buf(Number{}); + const float val_1 = src_buf[Number{}]; + static_cast_float_to_bhalf_packed_v2(val_0, val_1); + }); + }; + }; +} diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 119b723be1..15668f06e6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -13,6 +13,200 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3_pass_through +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr bool float_input_and_bf16_output_ = + std::is_same_v && std::is_same_v; + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_pass_through(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx, + const ElementwiseOperation& element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_assert(std::is_same_v>, + "wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>"); + + static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); + static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector"); + + static_for<0, num_access, 1>{}([&](auto idx_1d) + { + // We need map the odd indices to the even indices, since + // the even indices contain a packed bf16x2 value, where + // the first value contains the bf16 value for the corresponding even index + // and the second value contains the bf16 value for the odd index following the even index. + // The odd indices are not used, so we can just ignore them. + constexpr auto pair_index = idx_1d % I2; + constexpr auto idx_src_1d = idx_1d - pair_index; + + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d); + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md); + + union + { + float src_float; + bhalf2_t src_bf16x2; + } packed_value; + + packed_value.src_float = src_buf[Number{}]; + + DstData v; + element_op_(v, packed_value.src_bf16x2[pair_index.value]); + dst_vector.template AsType()(I0) = v; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; // namespace ThreadwiseTensorSliceTransfer_v1r3_pass_through // Assume: // 1. src: @@ -292,7 +486,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); typename vector_type_maker::type dst_vector; using dst_vector_t = typename vector_type_maker::type::type; @@ -305,12 +499,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); + dst_vector.template AsType()(i) = v; }); @@ -379,6 +574,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } + private: DstCoord dst_coord_; const ElementwiseOperation element_op_; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index d016ffba4f..95ba91ae20 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -91,6 +91,25 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl #endif } +/** + * @brief Converts two floats into a vector of 2 16-bit bfloat types (bhalf2_t) using + * rounding to nearest/even (RNE). + * + * @param x First float value. The packed value is stored here. + * @param y Second float value. + */ +inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, float y) +{ + union + { + float fp32; + bhalf2_t bf16x2; + } converter; + + converter.bf16x2 = {bf16_convert_rtn(x), bf16_convert_rtn(y)}; + x = converter.fp32; +} + /** * @brief Converts two floats into a vector of 2 16-bit bfloat types (bhalf2_t) using * rounding to nearest/even (RNE). diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp index b2fc813d33..bf5135ebb8 100644 --- a/test/data_type/test_bhalf.cpp +++ b/test/data_type/test_bhalf.cpp @@ -15,45 +15,6 @@ using ck::bhalf_t; using ck::type_convert; -TEST(BHALF_T, Nan) -{ - const uint16_t binary_bhalf_nan = 0x7FC0; - const bhalf_t bhalf_nan = ck::bit_cast(binary_bhalf_nan); - EXPECT_EQ(bhalf_nan, type_convert(ck::NumericLimits::QuietNaN())); -} - -TEST(BHALF_T, Inf) -{ - const uint16_t binary_bhalf_inf = 0x7F80; - const bhalf_t bhalf_inf = ck::bit_cast(binary_bhalf_inf); - EXPECT_EQ(bhalf_inf, type_convert(ck::NumericLimits::Infinity())); -} - -TEST(BHALF_T, MantisaOverflow) -{ - const float abs_tol = std::pow(2, -7); - const uint32_t val = 0x81FFFFFF; - const float float_val = ck::bit_cast(val); - - ASSERT_NEAR(float_val, type_convert(type_convert(float_val)), abs_tol); -} - -TEST(BHALF_T, ExpOverflow) -{ - const uint32_t val = 0xFF800000; - const float float_val = ck::bit_cast(val); - ASSERT_EQ(type_convert(type_convert(float_val)), float_val); -} - -TEST(BHALF_T, MantisaExpOverflow) -{ - const uint32_t val = 0xFFFFFFFF; - const float float_val = ck::bit_cast(val); - - ASSERT_TRUE(std::isnan(float_val)); - ASSERT_TRUE(std::isnan(type_convert(type_convert(float_val)))); -} - __global__ void cast_roundtrip(const float2 input, float2* output) { const ck::bhalf2_t bhalf2_val = ck::bf16x2_convert_rne(input.x, input.y); @@ -74,10 +35,26 @@ __global__ void cast(const float input, float* output) *output = type_convert(bhalf_val); } +__global__ void packed_cast_in_place(const float x1, const float x2, ck::bhalf2_t* output) +{ + union + { + float src; + ck::bhalf2_t dst; + } converter; + + float x = x1; + ck::static_cast_float_to_bhalf_packed_v2(x, x2); + + converter.src = x; + *output = converter.dst; +} + enum struct CastMode : int { Standard = 0, - Packed = 1 + Packed = 1, + PackedInPlace = 2 }; template @@ -119,6 +96,60 @@ __global__ void test_performance_kernel(float* input, ck::bhalf_t* output) } } +template +__global__ void test_in_place_performance_kernel(float* input, ck::bhalf_t* output) +{ + ck::bhalf_t buffer_bf16[NumElements]; + float buffer_float[NumElements]; + + // Initialize input data + for(int i = 0; i < NumElements; i++) + { + buffer_float[i] = input[i]; + } + + if constexpr (PackedCast == CastMode::PackedInPlace) + { + union + { + float src; + ck::bhalf2_t dst; + } workspace; + + for(int i = 0; i < NumElements; i++) + { + for (int j = 0; j < NumElements; j++) + { + int index = (i + j) % NumElements; + index = index < NumElements - 1 ? index : NumElements - 2; + workspace.src = buffer_float[i]; + ck::static_cast_float_to_bhalf_packed_v2(workspace.src, buffer_float[j]); + ck::bhalf2_t* buffer_range = reinterpret_cast(&buffer_bf16[index]); + *buffer_range = workspace.dst; + } + } + } + else + { + for(int i = 0; i < NumElements; i++) + { + for (int j = 0; j < NumElements; j++) + { + int index = (i + j) % NumElements; + index = index < NumElements - 1 ? index : NumElements - 2; + buffer_bf16[index] = ck::bf16_convert_rtn_base(buffer_float[i]); + buffer_bf16[index + 1] = ck::bf16_convert_rtn_base(buffer_float[j]); + } + } + } + + // Copy results back to output + for(int i = 0; i < NumElements; i++) + { + output[i] = buffer_bf16[i]; + } +} + template void run_performance_test() { @@ -165,6 +196,91 @@ void run_performance_test() ASSERT_LT(packed_time, baseline_time); } +template +void run_in_place_performance_test() +{ + float* input_dev; + ck::bhalf_t* output_dev; + std::vector output_host(NumElements); + + hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements)); + hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements)); + + // Initialize input data on the device + std::vector input_host(NumElements); + for (int i = 0; i < NumElements; i++) + { + input_host[i] = 3.14f * static_cast(i) - 1.7f; + } + + hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice)); + + StreamConfig stream_config; + stream_config.time_kernel_ = true; + + auto baseline_kernel = test_in_place_performance_kernel; + auto packed_kernel = test_in_place_performance_kernel; + + constexpr dim3 grid_size(1); + constexpr dim3 block_size(1); + constexpr size_t shared_mem_size = 0; + + const float baseline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev); + hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost)); + + const float packed_time = launch_and_time_kernel(stream_config, packed_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev); + hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost)); + + // Cleanup + hip_check_error(hipFree(input_dev)); + hip_check_error(hipFree(output_dev)); + + std::cout << "Packed cast time ( " << NumElements << " elements): " << packed_time << " ms" << std::endl; + std::cout << "Baseline cast time ( " << NumElements << " elements): " << baseline_time << " ms" << std::endl; + + // Check if packed cast is faster than baseline + ASSERT_LT(packed_time, baseline_time); +} + +TEST(BHALF_T, Nan) +{ + const uint16_t binary_bhalf_nan = 0x7FC0; + const bhalf_t bhalf_nan = ck::bit_cast(binary_bhalf_nan); + EXPECT_EQ(bhalf_nan, type_convert(ck::NumericLimits::QuietNaN())); +} + +TEST(BHALF_T, Inf) +{ + const uint16_t binary_bhalf_inf = 0x7F80; + const bhalf_t bhalf_inf = ck::bit_cast(binary_bhalf_inf); + EXPECT_EQ(bhalf_inf, type_convert(ck::NumericLimits::Infinity())); +} + +TEST(BHALF_T, MantisaOverflow) +{ + const float abs_tol = std::pow(2, -7); + const uint32_t val = 0x81FFFFFF; + const float float_val = ck::bit_cast(val); + + ASSERT_NEAR(float_val, type_convert(type_convert(float_val)), abs_tol); +} + +TEST(BHALF_T, ExpOverflow) +{ + const uint32_t val = 0xFF800000; + const float float_val = ck::bit_cast(val); + ASSERT_EQ(type_convert(type_convert(float_val)), float_val); +} + +TEST(BHALF_T, MantisaExpOverflow) +{ + const uint32_t val = 0xFFFFFFFF; + const float float_val = ck::bit_cast(val); + + ASSERT_TRUE(std::isnan(float_val)); + ASSERT_TRUE(std::isnan(type_convert(type_convert(float_val)))); +} + TEST(BHALF_T, Performance) { if (ck::get_device_name() == "gfx950") @@ -365,3 +481,41 @@ TEST(BHALF_T, CastOnDevice) ASSERT_NEAR(float_val_after_cast_host, -float_vals[idx], abs_tol); } } + +TEST(BHALF_T, PackedCast_in_place) +{ + const float v1 = 3.14f; + const float v2 = -1.618f; + ck::bhalf2_t* bhalf2_val_d; + hip_check_error(hipMalloc(&bhalf2_val_d, sizeof(ck::bhalf2_t))); + + packed_cast_in_place<<<1, 1>>>(v1, v2, bhalf2_val_d); + hip_check_error(hipGetLastError()); + + ck::bhalf2_t bhalf2_val_h; + hip_check_error(hipMemcpy(&bhalf2_val_h, bhalf2_val_d, sizeof(ck::bhalf2_t), hipMemcpyDeviceToHost)); + + // Convert back to floats + const float fval1 = type_convert(bhalf2_val_h[0]); + const float fval2 = type_convert(bhalf2_val_h[1]); + + const float abs_tol = std::pow(2, -7); + ASSERT_NEAR(fval1, v1, abs_tol); + ASSERT_NEAR(fval2, v2, abs_tol); + + hip_check_error(hipFree(bhalf2_val_d)); +} + +TEST(BHALF_T, PackedCast_in_place_performance) +{ + if (ck::get_device_name() == "gfx950") + { + run_in_place_performance_test<32>(); + run_in_place_performance_test<64>(); + run_in_place_performance_test<128>(); + } + else + { + GTEST_SKIP() << "Packed cast performance test requires gfx950."; + } +}