From 481df169f281b126ebb5bfc232899c607959f641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 27 Aug 2025 11:27:03 +0000 Subject: [PATCH] Add packed cast to gridwise gemm multi d. --- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 179 +++++++++++++----- 1 file changed, 135 insertions(+), 44 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index b72c4d0313..983b83dc13 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -1307,6 +1308,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 c_grid_desc_m_n); } + // check if we should apply gfx950 device specific optimization for BF16 output + __device__ static constexpr bool is_gfx650_and_bf16_output() + { +#if defined(__gfx950__) + return + !DoElementwiseBeforeCShuffle && + std::is_same_v && + std::is_same_v; +#else + return false; +#endif + } + template , + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>, + ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>>; 
+ // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], @@ -1718,6 +1756,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); + if constexpr (is_gfx650_and_bf16_output()) + { + auto c_thread_packed_cast = PackedCastV2< + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -2111,28 +2164,51 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 }; // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + using ThreadwiseTransfer = std::conditional_t< + is_gfx650_and_bf16_output(), + ThreadwiseTensorSliceTransfer_v1r3_pass_through< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>, + ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>>; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], @@ -2250,6 +2326,21 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); + if constexpr (is_gfx650_and_bf16_output()) + { + auto c_thread_packed_cast = PackedCastV2< + M2, + M4, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle + >{}; + c_thread_packed_cast.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin + c_thread_buf // source buffer + ); + } + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id),