From 44202b9d3260c56b93abdcafafc17db096b2fecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Tue, 5 Aug 2025 15:12:36 +0000 Subject: [PATCH] WIP: Integration of packed cast into gridwise_gemm_xdl_cshuffle_conv_v3. --- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 47 +++++++++--- .../tensor_operation/gpu/grid/packed_cast.hpp | 76 ++++++++++++++++++- .../threadwise_tensor_slice_transfer.hpp | 35 ++++++--- include/ck/utility/type_convert.hpp | 26 +++++++ 4 files changed, 161 insertions(+), 23 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 4c76035112..996ccd0953 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -112,9 +112,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 static constexpr bool is_gfx950_and_bf16_input_ = false; #endif - using CShuffleInputDataType = std::conditional_t; + // using CShuffleInputDataType = std::conditional_t; + using CShuffleInputDataType = AccDataType; __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) { @@ -918,7 +919,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 1, InMemoryDataOperationEnum::Set, 1, - true>{ + true, + is_gfx950_and_bf16_input_>{ c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(0, 0, @@ -988,15 +990,26 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr (is_gfx950_and_bf16_input_) { - packed_cast(sfc_c_vgpr); + auto c_thread_packed_cast = PackedCast< + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); } // each thread write 
its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); // make sure it's safe to read from LDS block_sync_lds(); @@ -1295,7 +1308,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 1, InMemoryDataOperationEnum::Set, 1, - true>{ + true, + is_gfx950_and_bf16_input_>{ c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(0, 0, @@ -1365,7 +1379,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr (is_gfx950_and_bf16_input_) { - packed_cast(sfc_c_vgpr); + auto c_thread_packed_cast = PackedCast< + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); } // each thread write its data from VGPR to LDS diff --git a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp index f0b226aaae..6d0f42aef1 100644 --- a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp +++ b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp @@ -2,6 +2,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" #include "ck/host_utility/hip_check_error.hpp" @@ -9,9 +11,77 @@ namespace ck { - __host__ __device__ inline void packed_cast(const auto& sfc_c_vgpr) + template + struct PackedCast { - // This function is a placeholder for packed cast operations. - // For now, it does nothing. 
+ template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + // Calculate total elements in this slice + constexpr index_t elements_per_slice = + CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4; + + constexpr auto calculate_coords = [&](auto idx) constexpr { + constexpr index_t m4_offset = idx.value % M4; + constexpr index_t m2_offset = (idx.value / M4) % M2; + constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle; + constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle); + + return make_tuple( + src_slice_origin_idx[Number<0>{}] + Number{}, + src_slice_origin_idx[Number<1>{}] + Number{}, + Number<0>{}, + Number<0>{}, + src_slice_origin_idx[Number<4>{}] + Number{}, + Number<0>{}, + src_slice_origin_idx[Number<6>{}] + Number{}, + Number<0>{} + ); + }; + + constexpr index_t num_pairs = elements_per_slice / 2; + constexpr bool has_odd_element = (elements_per_slice % 2 == 1); + + static_for<0, num_pairs, 1>{}([&](auto pair_idx) { + constexpr auto idx_0 = Number{}; + constexpr auto idx_1 = Number{}; + + constexpr auto coord_0 = calculate_coords(idx_0); + constexpr auto coord_1 = calculate_coords(idx_1); + + float& val_0 = src_buf[coord_0]; + float& val_1 = src_buf[coord_1]; + + // Use packed conversion + static_cast_float_to_bhalf_packed(val_0, val_1); + }); + + // Handle last element if the number of elements is odd. 
+ if constexpr (has_odd_element) + { + constexpr auto last_idx = Number{}; + constexpr auto last_coord = calculate_coords(last_idx); + + // Single element conversion + float& last_val = src_buf[last_coord]; + const auto single_bf16 = static_cast<__bf16>(last_val); + uint16_t* parts = reinterpret_cast(&last_val); + const uint16_t* bf16_bits = reinterpret_cast(&single_bf16); + parts[1] = bf16_bits[0]; + } + + }; }; } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2305997f70..1305180ed4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -34,11 +34,15 @@ template ::type = false> + typename enable_if::type = false, + bool PackedInput = false> struct ThreadwiseTensorSliceTransfer_v1r3 { static constexpr index_t nDim = SliceLengths::Size(); + static constexpr bool float_input_and_bf16_output_ = + std::is_same_v && std::is_same_v; + using Index = MultiIndex; using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); @@ -106,17 +110,30 @@ struct ThreadwiseTensorSliceTransfer_v1r3 // copy data from src_buf into dst_vector // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? 
- static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + if constexpr (PackedInput && float_input_and_bf16_output_) + { + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - DstData v; + const float packed_float = src_buf[Number{}]; + const bhalf_t* bf16_array = reinterpret_cast(&packed_float); + dst_vector.template AsType()(i) = bf16_array[0]; + }); + } + else + { + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - // apply element-wise operation - element_op_(v, src_buf[Number{}]); + DstData v; - dst_vector.template AsType()(i) = v; - }); + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + dst_vector.template AsType()(i) = v; + }); + } const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 9d428c2897..53a905a7cd 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -68,6 +68,32 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x) } #endif +// TODO: Why do we need the host instance? 
+// Converts two floats to bf16 and packs each 16-bit result into the upper half of the
+// float it came from.
+//
+// x, y: in/out. On return bits [31:16] of each hold the bf16 encoding of its original
+//       value and bits [15:0] are left unchanged. The arguments are packed containers
+//       afterwards and should no longer be used as ordinary float values.
+//
+// NOTE(review): the PackedInput branch added to ThreadwiseTensorSliceTransfer_v1r3 in this
+// patch reads element [0] of the packed float (the lower half on little-endian), whereas
+// this function fills the upper half. Confirm which half is intended before merging.
+inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y)
+{
+    constexpr uint32_t low_half_mask = 0x0000FFFFu;
+
+    // Work on integer copies of the bit patterns. Viewing the floats as uint16_t[2]
+    // through reinterpret_cast violates the strict-aliasing rules.
+    uint32_t x_bits = 0;
+    uint32_t y_bits = 0;
+    __builtin_memcpy(&x_bits, &x, sizeof(x_bits));
+    __builtin_memcpy(&y_bits, &y, sizeof(y_bits));
+
+#if defined(__gfx950__)
+    // One instruction converts both inputs: the low half of the result belongs to x and
+    // the high half to y.
+    uint32_t packed;
+    asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2" : "=v"(packed) : "v"(x), "v"(y));
+
+    const uint32_t x_bf16 = packed & low_half_mask;
+    const uint32_t y_bf16 = packed >> 16;
+#else
+    // Other targets and the host: software round-to-nearest-even, so that the storage
+    // layout is the same as on gfx950.
+    const auto to_bf16_bits = [](uint32_t bits) -> uint32_t {
+        if((bits & 0x7FFFFFFFu) > 0x7F800000u)
+        {
+            // NaN: truncate and set a mantissa bit so the result cannot collapse to Inf.
+            return (bits >> 16) | 0x0040u;
+        }
+        // Adding 0x7FFF plus the current LSB of the kept part rounds ties to even.
+        return (bits + 0x7FFFu + ((bits >> 16) & 1u)) >> 16;
+    };
+
+    const uint32_t x_bf16 = to_bf16_bits(x_bits);
+    const uint32_t y_bf16 = to_bf16_bits(y_bits);
+#endif
+
+    // Replace only the upper 16 bits of each float.
+    x_bits = (x_bits & low_half_mask) | (x_bf16 << 16);
+    y_bits = (y_bits & low_half_mask) | (y_bf16 << 16);
+
+    __builtin_memcpy(&x, &x_bits, sizeof(x));
+    __builtin_memcpy(&y, &y_bits, sizeof(y));
+}
+
 // Declare a template function for conversion of bf16 vector of two values using RNE
 template <typename Y, typename X>
 __host__ __device__ constexpr Y bf16x2_convert_rne(X x);