From cee7644c85f4c6b057376b70929a709727d801bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Tue, 12 Aug 2025 12:46:01 +0000 Subject: [PATCH] Working version 2 of the packed cast. --- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 14 +-- .../tensor_operation/gpu/grid/packed_cast.hpp | 14 ++- .../threadwise_tensor_slice_transfer.hpp | 92 ++++++++++++++----- .../test_grouped_convnd_bwd_weight.cpp | 3 +- 4 files changed, 80 insertions(+), 43 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index a09147db2b..f822e8b2cc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -800,15 +800,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - if constexpr(is_gfx950_and_bf16_input_) - { - constexpr auto register_size_per_xdl_op = blockwise_gemm_pipeline.xdlops_gemm.GetRegSizePerXdlops(); - StaticBufferTupleOfVector c_thread_buf_bf16; - } const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / @@ -1010,8 +1001,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 c_thread_packed_cast.Run( c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf, // source buffer, - c_thread_buf_bf16 // destination buffer + c_thread_buf // source buffer ); } @@ -1390,7 +1380,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr (is_gfx950_and_bf16_input_) { - auto c_thread_packed_cast = PackedCast< + auto c_thread_packed_cast = PackedCastV2< M2, M4, CShuffleMXdlPerWavePerShuffle, 
diff --git a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp index 1264b515b9..111af2fb29 100644 --- a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp +++ b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp @@ -6,6 +6,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { @@ -94,6 +95,10 @@ namespace ck { template struct PackedCastV2 { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + template __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf) { @@ -130,8 +135,8 @@ namespace ck { static_for<0, num_pairs, 1>{}([&](auto i_pair) { - constexpr auto idx_1d_0 = 2 * i_pair; - constexpr auto idx_1d_1 = 2 * i_pair + 1; + constexpr auto idx_1d_0 = I2 * i_pair; + constexpr auto idx_1d_1 = I2 * i_pair + I1; constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0); constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1); @@ -146,16 +151,17 @@ namespace ck { // Handle last element if the number of elements is odd. 
if constexpr (has_odd_element) { - constexpr auto last_idx_1d = num_access - 1; + constexpr auto last_idx_1d = Number{}; constexpr auto last_idx_md = SpaceFillingCurve::GetIndex(last_idx_1d); // Single element conversion constexpr auto last_src_offset = src_desc.CalculateOffset(last_idx_1d); - float& last_val = src_buf(Number{}); + float& last_val = src_buf(Number{}); const auto single_bf16 = static_cast<__bf16>(last_val); uint16_t* parts = reinterpret_cast(&last_val); const uint16_t* bf16_bits = reinterpret_cast(&single_bf16); parts[0] = bf16_bits[0]; } + }; }; } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 5926c90d58..bf3f0f91e9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -35,10 +35,14 @@ template ::type = false, - bool PackedInput = false> + bool PackedInput = false, + typename enable_if::type = false> struct ThreadwiseTensorSliceTransfer_v1r3 { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr index_t nDim = SliceLengths::Size(); static constexpr bool float_input_and_bf16_output_ = @@ -86,38 +90,76 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector? 
+ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); if constexpr (PackedInput && float_input_and_bf16_output_) { - static_assert(DstScalarPerVector == 2, "wrong! DstScalarPerVector must be 2"); static_assert(std::is_same_v>, "wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>"); - // TODO: Fill the dst_vector vectprized access of size 2. - // Copying the dst_vector into dst_buf should be the same as before. + static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector"); + static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector"); + + static_for<0, num_access, 1>{}([&](auto idx_1d) + { + // We need to map the odd indices to the even indices, since + // the even indices contain a packed bf16x2 value, where + // the first value contains the bf16 value for the corresponding even index + // and the second value contains the bf16 value for the odd index following the even index. + // The odd indices are not used, so we can just ignore them. 
+ constexpr auto pair_index = idx_1d % I2; + constexpr auto idx_src_1d = idx_1d - pair_index; + + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d); + constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md); + + union + { + float src_float; + bhalf16_t src_bf16x2; + } packed_value; + + packed_value.src_float = src_buf[Number{}]; + dst_vector.template AsType()(I0) = packed_value.src_bf16x2[pair_index.value]; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); } else { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? 
- static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); - typename vector_type_maker::type dst_vector; - using dst_vector_t = typename vector_type_maker::type::type; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - static_for<0, num_access, 1>{}([&](auto idx_1d) { constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 79cf5c4f53..cf7cda34cb 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -227,7 +227,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) } // Use --gtest_filter="TestGroupedConvndBwdWeight2d_bf16_gfx950/*" to only this subset of tests. -#if defined(__gfx950__) +// List tests in the test suite with --gtest_list_tests template class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeight @@ -257,4 +257,3 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D) this->Run(); } -#endif