From 86c8bef5d763f81dc5ea89bab434f126e64d1070 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 3 Jun 2025 14:54:20 +0000 Subject: [PATCH] Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example --- .../01_gemm/gemm_xdl_lds_direct_load_fp16.cpp | 2 +- .../splitK_gemm_xdl_lds_direct_load_fp16.cpp | 2 +- ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 2 +- ...roup_tensor_slice_transfer_direct_load.hpp | 34 ++++++++++++------- ...m_xdl_splitk_c_shuffle_lds_direct_load.hpp | 2 ++ ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 14 ++++---- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 33 ++++++++++++------ .../threadwise_tensor_slice_transfer_util.hpp | 12 +++++++ .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 9 ----- 9 files changed, 68 insertions(+), 42 deletions(-) diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index 62037f7740..b020101342 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on #else // clang-format off diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp index 97a3f89e5e..d6b5a90cfc 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on #else diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index e5fe92a50d..8b227a8aa1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}; + XdlopsGemm{}; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index 6854eaafab..89917d8f6e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -50,8 +50,7 @@ template + index_t ScalarPerVector> struct ThreadGroupTensorSliceTransfer_DirectLoad { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); @@ -68,20 +67,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; - static constexpr auto wave_thread_cluster_lengths = - Sequence{}; - static constexpr auto wave_cluster_lengths = - Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; - static constexpr auto wave_single_load_size = - wave_thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static __device__ constexpr bool AreThreadClusterLengthsValid() @@ -180,6 +171,25 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + constexpr auto wave_cluster_lengths = generate_sequence_v2( + [&](auto i) { + if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3)) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths; + constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; + constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( make_multi_index(ThreadGroup::GetThreadId() / 64)); @@ -327,8 +337,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad private: static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - static constexpr auto wave_cluster_desc_ = - make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp index d704d04054..eda966c48a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); + return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(AK1, Number{}, I1)); } __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // B matrix in LDS memory, destination of blockwise copy. - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); + return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(BK1, Number{}, I1)); } __host__ __device__ static constexpr auto @@ -566,10 +564,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ThreadGroupTensorSliceTransfer_DirectLoad, ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, ADataType, AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, 2, ABlockTransferScalarPerVector>( @@ -582,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ThreadGroupTensorSliceTransfer_DirectLoad, BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, BDataType, BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, 2, BBlockTransferScalarPerVector>( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index bac8c32886..3e23008a5f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -76,10 +76,12 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; - static constexpr auto M01 = 1; - static constexpr auto N01 = 1; + static constexpr auto K1 = Number{}; + static constexpr auto KPerBlock = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; static constexpr auto gemm_padder = tensor_operation::device::GemmPadder{ @@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); } }(); @@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( + return make_naive_tensor_descriptor( make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); } }(); // B matrix in LDS memory, dst of blockwise copy @@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); } }(); @@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load } else { - return make_naive_tensor_descriptor_aligned( + return make_naive_tensor_descriptor( make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); } }(); @@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load ThreadGroupTensorSliceTransfer_DirectLoad, ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, FloatA, ComputeType, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, 3, ABlockTransferSrcScalarPerVector>( @@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load ThreadGroupTensorSliceTransfer_DirectLoad, BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, FloatB, ComputeType, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, 3, BBlockTransferSrcScalarPerVector>( diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp index 96b95579f5..168f028e2a 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst } }; +template +struct lambda_wave_cluster_dimension +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if((nDim - i) == 3) + return WaveNum; + else + return 1; + } +}; + } // namespace detail } // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 9248af0a4b..b41224e4b4 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1159,15 +1159,6 @@ struct MfmaSelector #endif } - // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) - // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 - // TODO: explore optimization opportunity by using new mfma instructions on gfx950 - template <> - constexpr auto GetMfma() - { - return MfmaInstr::mfma_f32_32x32x16f8f8; - } - template <> constexpr auto GetMfma() {