From ed047d08b4cd6e0afcbb58f473a4568e0dfa566d Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Thu, 29 May 2025 12:33:30 +0500 Subject: [PATCH] Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. --- .../impl/device_gemm_wmma_cshuffle_v3.hpp | 43 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 395 ++++++++++++------ 2 files changed, 300 insertions(+), 138 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp index 90afc467d4..ed34468c58 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -180,11 +180,13 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, // DsLayout CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, + Tuple<>, // DsDataType CDataType, AElementwiseOperation, BElementwiseOperation, @@ -219,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -294,7 +296,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, 0, arg_.M * arg_.N * sizeof(CDataType), stream_config.stream_id_)); @@ -312,7 +314,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); @@ -468,11 +470,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2{}, // p_ds_grid_ + p_c, + M, + N, + K, + StrideA, + StrideB, + std::array{}, // StrideDs_ + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -488,20 +504,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2(static_cast(p_a), static_cast(p_b), + std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, StrideA, StrideB, + std::array{}, // StrideDs_ StrideC, - KBatch); + KBatch, + a_element_op, + b_element_op, + c_element_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f3354cd5dd..666599bf44 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -19,7 +19,7 @@ namespace ck { template __global__ void @@ -31,22 +31,26 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) { #endif __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); #if defined(__gfx11__) } #endif @@ -59,8 +63,8 @@ __global__ void /// /// @par Overview /// This GEMM kernel is carrying out following mathematical equation: -/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) -/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B/CDE_op are /// elementwise operations that could be applied on each tensor respectively. /// The \"universal\" gemm comes with multiple pipelines optimized for different usage /// scenarios. That's why it's called \"universal\". It's universal through it's design @@ -73,18 +77,19 @@ __global__ void /// /// @tparam ALayout A tensor data layout. /// @tparam BLayout B tensor data layout. -/// @tparam CLayout C tensor data layout. +/// @tparam DsLayout D tensors data layout. +/// @tparam ELayout E tensor data layout. /// @tparam ADataType A tensor data type. /// @tparam BDataType B tensor data type. /// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into /// LDS memory during \"CShuffle\" data layout optimization. -/// @tparam CDataType C tensor data type. -/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. -/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. -/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor -/// (after GEMM). +/// @tparam EDataType E tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. /// @tparam GemmSpec Determines used "padding" version. /// @tparam BlockSize The number of threads within workgroup. /// @tparam MPerBlock The input/output data tile size in the M dimension. @@ -142,11 +147,12 @@ __global__ void /// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions /// results to process per wave per iteration of CShuffle /// in N dimension. -/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial /// thread distribution used for storing data into output /// tensor across output data layout dimensions. -/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. -/// Used when storing data to output tensor. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. /// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or /// intrawave). /// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. @@ -160,15 +166,17 @@ __global__ void /// in global memory (pre-shuffled). template {}; static constexpr auto I7 = Number<7>{}; + // TODO: remove + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> static constexpr auto AK0Number = Number{}; static constexpr auto BK0Number = Number{}; @@ -530,17 +542,18 @@ struct GridwiseGemm_wmma_cshuffle_v3 return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); } - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + template + __device__ static auto + MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) { const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE)); } }(); @@ -593,6 +606,44 @@ struct GridwiseGemm_wmma_cshuffle_v3 #endif } + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDEGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + struct Problem { __host__ Problem(index_t M_, @@ -600,14 +651,16 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t KBatch_) : M{M_}, N{N_}, K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, - StrideC{StrideC_}, + StrideDs{StrideDs_}, + StrideE{StrideE_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -627,8 +680,16 @@ struct GridwiseGemm_wmma_cshuffle_v3 << "N:" << N << ", " << "K:" << K << ", " << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " + << "SB:" << StrideB << ", "; + if constexpr(NumDTensor > 0) + { + std::cout << "SDs: { "; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : ""); + }); + std::cout << " }, "; + } + std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " @@ -644,7 +705,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K; index_t StrideA; index_t StrideB; - index_t StrideC; + std::array StrideDs; + index_t StrideE; index_t KBatch; index_t MPadded; index_t NPadded; @@ -661,21 +723,35 @@ struct GridwiseGemm_wmma_cshuffle_v3 { __host__ Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, - CDataType* p_c_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, - p_c_grid{p_c_grid_}, + p_ds_grid{}, + p_e_grid{p_e_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); } __host__ __device__ inline bool IsReduceAdd() const @@ -690,42 +766,49 @@ struct GridwiseGemm_wmma_cshuffle_v3 const ADataType* p_a_grid; const BDataType* p_b_grid; - CDataType* p_c_grid; + DsGridPointer p_ds_grid; + EDataType* p_e_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CDEElementwiseOperation cde_element_op; + + // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; }; struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + b_k_split_offset = k_id * karg.KRead / BPackedSize; } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + b_k_split_offset = k_id * k0_offset / BPackedSize; } } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } @@ -736,7 +819,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(karg.IsReduceAdd()) { - c_reduce_offset = blockIdx.z * karg.M * karg.N; + c_reduce_offset = k_id * karg.M * karg.N; } else { @@ -1143,7 +1226,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + std::cout << "Arg K value is not a multiple of K_Batch * KPerBlock! K: " << karg.K << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } @@ -1219,7 +1302,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 } } - if constexpr(is_same::value) + if constexpr(is_same::value) { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { @@ -1252,23 +1335,20 @@ struct GridwiseGemm_wmma_cshuffle_v3 } } - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) { - if(!karg.IsReduceAdd()) + if(karg.IsAtomicAdd() && karg.KBatch > 1) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this " + << "destination type (EDataType) " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1301,18 +1381,18 @@ struct GridwiseGemm_wmma_cshuffle_v3 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + template + __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, + const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + de_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), make_unmerge_transform(make_tuple(NBlock, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - return c_grid_desc_mblock_mperblock_nblock_nperblock; + return de_grid_desc_mblock_mperblock_nblock_nperblock; } // return block_id to C matrix tile idx (m0, n0) mapping @@ -1322,30 +1402,40 @@ struct GridwiseGemm_wmma_cshuffle_v3 template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, void* p_shared, const Problem& problem, const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1355,8 +1445,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } @@ -1483,7 +1573,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 c_thread_buf, num_k_block_main_loop); - // shuffle C and write out + // Epilogue: shuffle C for better memory access pattern, apply elementwise operation to + // C and Ds, write out result E to global memory { // C mapping in single thread. constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = @@ -1601,31 +1692,60 @@ struct GridwiseGemm_wmma_cshuffle_v3 m_thread_data_on_block_idx[I3]), ck::tensor_operation::element_wise::PassThrough{}}; - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy which loads C from LDS, D from global, applies elementwise + // operation and stores result E to global + auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOp, Sequence<1, CShuffleMRepeatPerShuffle * MWave * MPerWmma, 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor, + false>>, // bool ThreadTransferSrcResetCoordinateAfterRun, + Sequence> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; // space filling curve for local reg & global memory // space filling curve for threadwise C in VGPR @@ -1641,7 +1761,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 MAccVgprs>>{}; // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = + constexpr auto sfc_cde_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, Sequence<1, @@ -1651,7 +1771,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS @@ -1668,57 +1788,78 @@ struct GridwiseGemm_wmma_cshuffle_v3 // make sure it's safe to read from LDS block_sync_lds(); - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + // each block loads its C data from LDS, D from global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_to_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); if constexpr(access_id < num_access - 1) { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + // move on E + cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_global_step); } }); } } template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto e_grid_desc_m_n = MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, problem.MBlock, problem.NBlock); Run(p_a_grid, p_b_grid, - p_c_grid, + p_ds_grid, + p_e_grid, p_shared, problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op); } };