From 4b8a559da9355cf69531fd33350a8e55db3ab7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 6 Aug 2025 11:04:29 +0000 Subject: [PATCH] Fixed packed_cast implementation for slice access. --- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 4 +- .../tensor_operation/gpu/grid/packed_cast.hpp | 57 +++++++++++-------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 996ccd0953..629935b927 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -991,14 +991,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr (is_gfx950_and_bf16_input_) { auto c_thread_packed_cast = PackedCast< - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), M2, M4, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle >{}; c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin c_thread_buf // source buffer ); @@ -1380,7 +1379,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr (is_gfx950_and_bf16_input_) { auto c_thread_packed_cast = PackedCast< - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), M2, M4, CShuffleMXdlPerWavePerShuffle, diff --git a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp index 6d0f42aef1..89b31ca3eb 100644 --- a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp +++ b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp @@ -11,44 +11,47 @@ namespace ck { - template + template struct PackedCast { - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf) + template + __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); - static_assert(is_known_at_compile_time>::value, "wrong! SrcSliceOrigin need to known at compile-time"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); // Calculate total elements in this slice constexpr index_t elements_per_slice = CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4; - constexpr auto calculate_coords = [&](auto idx) constexpr { - constexpr index_t m4_offset = idx.value % M4; - constexpr index_t m2_offset = (idx.value / M4) % M2; - constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle; - constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle); - - return make_tuple( - src_slice_origin_idx[Number<0>{}] + Number{}, - src_slice_origin_idx[Number<1>{}] + Number{}, - Number<0>{}, - Number<0>{}, - src_slice_origin_idx[Number<4>{}] + Number{}, - Number<0>{}, - src_slice_origin_idx[Number<6>{}] + Number{}, - Number<0>{} - ); + constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr { + + // We know that the access order is + // Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + // Sequence + + constexpr index_t m4_offset = idx.value % M4; + constexpr index_t m2_offset = (idx.value / M4) % M2; + constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle; + constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle); + + return make_tuple( + src_slice_origin_idx[Number<0>{}] + Number{}, + src_slice_origin_idx[Number<1>{}] + Number{}, + Number<0>{}, // this dim has unit size + Number<0>{}, // this dim has unit size + src_slice_origin_idx[Number<4>{}] + Number{}, + Number<0>{}, // this dim has unit size + src_slice_origin_idx[Number<6>{}] + Number{}, + Number<0>{} // this dim has unit size + ); }; constexpr index_t num_pairs = elements_per_slice / 2; @@ -60,9 +63,12 @@ namespace ck { constexpr auto coord_0 = calculate_coords(idx_0); constexpr auto coord_1 = calculate_coords(idx_1); + + constexpr auto offset_0 = src_desc.CalculateOffset(coord_0); + constexpr auto offset_1 = src_desc.CalculateOffset(coord_1); - float& val_0 = src_buf[coord_0]; - float& val_1 = src_buf[coord_1]; + float& val_0 = src_buf(Number{}); + float& val_1 = src_buf(Number{}); // Use packed conversion static_cast_float_to_bhalf_packed(val_0, val_1); @@ -75,7 +81,8 @@ namespace ck { constexpr auto last_coord = calculate_coords(last_idx); // Single element conversion - float& last_val = src_buf[last_coord]; + constexpr auto last_offset = src_desc.CalculateOffset(last_coord); + float& last_val = src_buf[Number{}]; const auto single_bf16 = static_cast<__bf16>(last_val); uint16_t* parts = reinterpret_cast(&last_val); const uint16_t* bf16_bits = reinterpret_cast(&single_bf16);