From 71bd07a7837e39aab7bae8f519c13c786191aebb Mon Sep 17 00:00:00 2001
From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com>
Date: Fri, 31 Oct 2025 19:19:26 +0100
Subject: [PATCH] WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

  - Refactor the epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds the common parts
    - EpilogueCShuffle runs CShuffle and writes the result out
    - EpilogueWelfordCShuffle holds the Welford-specific arguments; it runs
      CShuffle, writes the result out, then runs the Welford first part and
      writes the Welford results out
  - Extend thread transfer v7r3:
    - Support an intermediate data type different from the src and dst types
    - Add functionality to write to the dst buffer while keeping the data in
      registers, so it can be reused for additional operations

* Address review comments

[ROCm/composable_kernel commit: 4ebc48a3cdd2e46732815f6de0b11c7856936c57]
---
 .../blockwise_gemm_pipeline_wmmaops_base.hpp  |   4 +-
 ...hread_group_tensor_slice_transfer_v7r3.hpp |  46 +-
 .../device_batched_gemm_wmma_cshuffle_v3.hpp  |   9 +-
 ..._batched_gemm_wmma_cshuffle_v3_b_scale.hpp |   9 +-
 ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 896 ++++++++++++++++++
 .../gpu/element/element_wise_operation.hpp    |  13 +-
 .../epilogue_cshuffle_v3_welford_wmma.hpp     | 510 ++++++++++
 .../gpu/grid/epilogue_cshuffle_v3_wmma.hpp    | 195 ++++
 .../grid/epilogue_cshuffle_v3_wmma_base.hpp   | 253 +++++
 .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp   |  24 +-
 ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp |  24 +-
 .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 331 ++-----
 .../threadwise_tensor_slice_transfer_v7r3.hpp | 133 ++-
 include/ck_tile/host/tensor_shuffle_utils.hpp |   0
 .../gpu/gemm_add_relu_add_layernorm.hpp       |  94 +-
 .../CMakeLists.txt                            |   7 +-
 ..._layernorm_f16_km_kn_mn_mn_mn_instance.cpp | 108 +++
 ..._layernorm_f16_km_nk_mn_mn_mn_instance.cpp | 108 +++
 ..._layernorm_f16_mk_kn_mn_mn_mn_instance.cpp | 108 +++
 ..._layernorm_f16_mk_nk_mn_mn_mn_instance.cpp | 105 ++
 ...ofile_gemm_add_relu_add_layernorm_impl.hpp |  16 +-
 test/gemm_layernorm/CMakeLists.txt            |  12 +-
 ...test_gemm_add_relu_add_layernorm_fp16.cpp} |   5 -
 23 files changed, 2678 insertions(+), 332 deletions(-)
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp
 mode change 100755 => 100644 include/ck_tile/host/tensor_shuffle_utils.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
 rename test/gemm_layernorm/{test_gemm_add_relu_add_layernorm_fp16_xdl.cpp => test_gemm_add_relu_add_layernorm_fp16.cpp} (96%)

diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 28c871ae0d..265db9166a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -292,13 +292,15 @@ struct BlockwiseGemmWmmaops_pipeline_base make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); } + static constexpr auto MAccVgprs = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2]; + __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; return make_naive_tensor_descriptor( // |MRepeat |MWave |MSubGroup |NRepeat |NWave diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp index 46d0c6ac2e..47924204a4 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -42,7 +42,8 @@ template + index_t NumThreadScratch = 1, + typename InterDatas = DstDatas> struct ThreadGroupTensorSliceTransfer_v7r3 { static constexpr index_t nDim = @@ -97,7 +98,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), "wrong! ThreadGroup::GetNumOfThread() too small"); - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( @@ -123,7 +124,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 const SrcBuffers& src_bufs, Number thread_scratch_id = Number{}) { - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); @@ -138,7 +139,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 DstBuffers dst_bufs, Number thread_scratch_id = Number{}) { - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { if constexpr(is_detected::value) @@ -148,6 +149,36 @@ struct ThreadGroupTensorSliceTransfer_v7r3 } } + template + __device__ void + RunWriteAndStoreVgpr(const DstDescs& dst_descs, + DstBuffers dst_bufs, + const DstVgprDescs& dst_vgpr_desc, + DstVgprBuffers dst_vgpr_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value && + is_detected::value) + threadwise_transfer_.RunWriteAndStoreVgpr( + dst_descs, dst_bufs, dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id); + else if 
constexpr(is_detected::value) + threadwise_transfer_.RunWriteAndStoreVgpr( + dst_descs, dst_bufs, dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id); + else if constexpr(is_detected::value) + threadwise_transfer_.RunWriteAndStoreVgpr( + dst_descs, tie(dst_bufs), dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id); + else + threadwise_transfer_.RunWriteAndStoreVgpr( + dst_descs, tie(dst_bufs), dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id); + } + } + template __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, @@ -162,7 +193,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) { - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); @@ -179,7 +210,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) { - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() || ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); @@ -212,7 +243,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3 DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, - NumThreadScratch>; + NumThreadScratch, + InterDatas>; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index e305dbfd9a..5542449a25 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -60,7 +60,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -82,6 +84,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + GridwiseGemm::template Run( p_as_grid_shift, p_bs_grid_shift, @@ -91,7 +95,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg, karg.a_element_op, karg.b_element_op, - karg.cde_element_op); + karg.cde_element_op, + epilogue_args); #if defined(__gfx11__) } #endif diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index 4f676528bc..74dd75d6ef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -46,12 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); // The normal approach to batching would be to increase the grid size by just stretching out // the grid Z dimension (which is the outermost dimension), but this depends on lower level // functions not directly using the Z dimension for other calculations. As it turns out, k // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now // we will use the grid Y dimension for batching. This may be a bit fragile. - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[LDS_size]; const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); @@ -84,6 +86,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + GridwiseGemm::template Run( p_as_grid_shift, p_bs_grid_shift, @@ -94,7 +98,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg, karg.a_element_op, karg.b_element_op, - karg.cde_element_op); + karg.cde_element_op, + epilogue_args); #if defined(__gfx11__) } #endif diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..780b799060 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -0,0 +1,896 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
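+
+// Design note: this device op fuses GEMM + Ds elementwise + layernorm using
+// two kernel launches:
+//   1. kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3 computes
+//      E = cde_element_op(A * B, Ds...) through the CShuffle epilogue and, in
+//      the same pass, per-N-block Welford partial mean/variance/count results
+//      (see EpilogueWelfordCShuffle).
+//   2. kernel_welford_layernorm2d_second_half merges the partial Welford
+//      results along N and applies the 2D layernorm with gamma/beta to
+//      produce H.
+// The intermediate mean/variance/count buffers (and E, when EMeanVarDataType
+// differs from HDataType) live in the workspace set up by SetWorkSpacePointer.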
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count_grid) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueWelfordCShuffle>(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle( + p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_welford_mean_grid; + ignore = p_welford_var_grid; + ignore = p_welford_count_grid; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_welford_layernorm2d_second_half( + const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N e_grid_desc_m_n, + const EHGridDesc_M_N h_grid_desc_m_n, + const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, + const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, + const GammaBetaGridDesc_N gamma_grid_desc_n, + const GammaBetaGridDesc_N beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) +{ + GridwiseWelfordLayernorm::Run(p_e_grid, + p_in_welford_mean_grid, + p_in_welford_var_grid, + p_in_welford_count_grid, + p_gamma_grid, + p_beta_grid, + p_h_grid, + e_grid_desc_m_n, + h_grid_desc_m_n, + mean_var_grid_desc_m_nblock, + count_grid_desc_m_nblock, + gamma_grid_desc_n, + beta_grid_desc_n, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + epsilon, + h_element_op); +} + +} // namespace ck + +namespace ck { 
+namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 + : public DeviceGemmMultipleDLayernorm +{ + // EDataType, MeanDataType and VarDataType must be the same. + using DeviceOp = DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t LayernormHDstVectorSize = CDEShuffleBlockTransferScalarPerVector; + static constexpr index_t LayernormGammaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector; + static constexpr index_t LayernormBetaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector; + static constexpr index_t LayernormESrcVectorSize = CDEShuffleBlockTransferScalarPerVector; + static constexpr index_t LayernormThreadSliceSize_N = CDEShuffleBlockTransferScalarPerVector; + + using LayernormBlockTileSize_M_N = + Sequence; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using CDEShuffleBlockTransferScalarPerVectors = + Sequence; + + // GEMM + Welford 1st part kernel + using GridwiseGemmWelford = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + HLayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EMeanVarDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + // Welford 2nd part kernel + template + static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride) + { + // Only support row major for E and H + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + static auto MakeDescriptor_X(index_t X) + { + const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X)); + return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence{}); + } + + using LayernormMeanVarGridDesc_M_NBlock = + decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + Sequence, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using LayernormCountGridDesc_M_NBlock = + decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + Sequence, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using GammaBetaGridDesc_N = decltype(MakeDescriptor_X(1)); + using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N, 1, 1>(1, 1, 1)); + + using GridwiseWelfordLayernorm = + 
GridwiseWelfordSecondHalfLayernorm2d; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + const void* p_gamma_grid, + const void* p_beta_grid, + void* p_h_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_workspace_e_grid_{nullptr}, + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + p_gamma_grid_{static_cast(p_gamma_grid)}, + p_beta_grid_{static_cast(p_beta_grid)}, + p_h_grid_{static_cast(p_h_grid)}, + layernorm_e_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + layernorm_mean_var_grid_desc_m_nblock_{}, + layernorm_count_grid_desc_m_nblock_{}, + gamma_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + beta_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + h_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + h_element_op_{h_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideDs_{StrideDs}, + StrideH_{StrideH}, + gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, + epsilon_{static_cast(epsilon)} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; }); + + layernorm_mean_var_grid_desc_m_nblock_ = + GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + Sequence, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); + + layernorm_count_grid_desc_m_nblock_ = + GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + Sequence, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + std::array p_ds_grid_; + void* p_workspace_e_grid_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + const GammaDataType* p_gamma_grid_; + const BetaDataType* p_beta_grid_; + HDataType* p_h_grid_; + + // tensor descriptors (Welford second half) + EHGridDesc_M_N layernorm_e_grid_desc_m_n_; + LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_; + LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_; + GammaBetaGridDesc_N gamma_grid_desc_n_; + GammaBetaGridDesc_N beta_grid_desc_n_; + EHGridDesc_M_N h_grid_desc_m_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + HElementwiseOperation h_element_op_; + + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + std::array StrideDs_; + index_t StrideH_; + index_t gemm_nblock_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemmWelford::Argument 
gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + arg.p_ds_grid_, + static_cast(arg.p_workspace_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + arg.StrideDs_, // StrideDs + arg.StrideH_, // StrideE + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemmWelford::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemmWelford::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); + } + + if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr || + arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemmWelford::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = + GridwiseGemmWelford::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel_gemm_welford_first_half) { + // Note: cache flushing not supported + + const auto kernel_welford_second_half = + kernel_welford_layernorm2d_second_half; + + // First kernel launch: GEMM + Welford first part + ave_time += + launch_and_time_kernel(stream_config, + kernel_gemm_welford_first_half, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + // Second kernel launch: Welford second part + const auto M = arg.h_grid_desc_m_n_.GetLength(I0); + const auto N = arg.h_grid_desc_m_n_.GetLength(I1); + + index_t MBlockClusterLength = + math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0)); + index_t NBlockClusterLength = + math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1)); + + auto grid_size = MBlockClusterLength * NBlockClusterLength; + + index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil( + arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1)); + + ave_time += launch_and_time_kernel( + stream_config, + kernel_welford_second_half, + dim3(grid_size), + dim3(BlockSize), + 0, + static_cast(arg.p_workspace_e_grid_), + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_gamma_grid_, + arg.p_beta_grid_, + arg.p_h_grid_, + arg.layernorm_e_grid_desc_m_n_, + arg.h_grid_desc_m_n_, + arg.layernorm_mean_var_grid_desc_m_nblock_, + arg.layernorm_count_grid_desc_m_nblock_, + arg.gamma_grid_desc_n_, + arg.beta_grid_desc_n_, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + arg.epsilon_, + arg.h_element_op_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3< + GridwiseGemmWelford, + EMeanVarDataType, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3< + GridwiseGemmWelford, + EMeanVarDataType, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // workspace for welford intermediate mean + workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128; + + // workspace for welford intermediate variance + workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128; + + // workspace for welford intermediate count + workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 128; + + if constexpr(!is_same_v) + workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType); + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 128); + + // setup buffer used for intermediate welford variance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 128); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + + index_t count_space_sz = gemm_welford_size * sizeof(int32_t); + count_space_sz = math::integer_least_multiple(count_space_sz, 128); + + if constexpr(!is_same_v) + pArg_->p_workspace_e_grid_ = + reinterpret_cast(pArg_->p_workspace_count_) + count_space_sz; + else + pArg_->p_workspace_e_grid_ = static_cast(pArg_->p_h_grid_); + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + // No need to check for splitK because we force KBatch = 1 (no support) + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == 
GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + typename GridwiseGemmWelford::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + arg.p_ds_grid_, + static_cast(arg.p_workspace_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + arg.StrideDs_, // StrideDs + arg.StrideH_, // StrideE + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemmWelford::MakeAsGridDescriptor_AK0_M_AK1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.KPadded, + gemm_arg.StrideAs, + gemm_arg.AK0); + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemmWelford::MakeBsGridDescriptor_BK0_N_BK1(gemm_arg.K, + gemm_arg.KPadded, + gemm_arg.N, + gemm_arg.NPadded, + gemm_arg.StrideBs, + gemm_arg.BK0); + + const auto M = a_grid_desc_ak0_m_ak1[I0].GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1[I0].GetLength(I1); + const auto K = + a_grid_desc_ak0_m_ak1[I0].GetLength(I0) * a_grid_desc_ak0_m_ak1[I0].GetLength(I2); + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + return GridwiseGemmWelford::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << 
"DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3" + << ">" + << "BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< - __host__ __device__ constexpr void operator()( - half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + __host__ __device__ constexpr void operator()( + float& y, const float& x0, const half_t& x1, const half_t& x2) const { float a = x0 + x1; float b = a > 0 ? a : 0; @@ -69,6 +69,15 @@ struct AddReluAdd y = c; } + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + { + float y_float; + (*this)(y_float, x0, x1, x2); + y = y_float; + } + template <> __host__ __device__ constexpr void operator()( bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp new file mode 100644 index 0000000000..85d13538cc --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp @@ -0,0 +1,510 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" + +namespace ck { + +template +struct EpilogueWelfordCShuffle + : EpilogueCShuffleBase +{ + using Base = EpilogueCShuffleBase< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe>; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + using Base::GetCShuffleLDSDescriptor; + using Base::GetVgprToLDSEpilogueDescriptor; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::NumDTensor; + + template + __host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N) + { + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + return tensor_operation::device::PadTensorDescriptor( + grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + __host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N) + { + // We will broadcast [N] to [M, N] in this descriptor + // Hence, 1st stride is 0 + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + return tensor_operation::device::PadTensorDescriptor( + grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + __host__ __device__ static constexpr auto + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) + { + const auto M = grid_desc_m_n.GetLength(I0); + const auto NBlock = grid_desc_m_n.GetLength(I1); + const auto MBlock = M / MPerBlock; + + const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor( + grid_desc_m_n, + 
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_pass_through_transform(NBlock)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{})); + + return grid_desc_mblock_mperblock_nblock; + } + + using GemmMeanVarGridDesc_M_N = + decltype(MakeMeanVarDescriptor_M_N, MPerBlock, 1>(1, 1)); + + using GemmCountGridDesc_M_N = + decltype(MakeCountDescriptor_M_N, MPerBlock, 1>(1, 1)); + + __device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_, + EDataType* p_welford_var_grid_, + int32_t* p_welford_count_grid_, + index_t MRaw_, + index_t NRaw_) + : p_welford_mean_grid(p_welford_mean_grid_), + p_welford_var_grid(p_welford_var_grid_), + p_welford_count_grid(p_welford_count_grid_), + NRaw(NRaw_) + { + index_t gemm_nblock = math::integer_divide_ceil(NRaw_, NPerBlock); + + gemm_mean_var_grid_desc_m_nblock = + MakeMeanVarDescriptor_M_N, MPerBlock, 1>(MRaw_, gemm_nblock); + + gemm_count_grid_desc_m_nblock = + MakeCountDescriptor_M_N, MPerBlock, 1>(MRaw_, gemm_nblock); + } + + template + __device__ void Run(CThreadBuf& c_thread_buf, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + // Vmem buffers + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto mean_var_grid_desc_mblock_mperblock_nblock = + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock); + + auto mean_grid_buf = make_dynamic_buffer( + p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto var_grid_buf = make_dynamic_buffer( + p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto count_grid_desc_mblock_mperblock_nblock = + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(gemm_count_grid_desc_m_nblock); + auto welford_count_grid_buf = make_dynamic_buffer( + p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + // LDS buffer + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + // tuple of reference to C/Ds tensor buffers (mix LDS and Vmem) + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // Thread transfer Vgpr to LDS + auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); + + // Space Filling Curve Vgpr + constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; + + // Space Filling Curve Vmem + constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; + + // C thread descriptor + constexpr auto 
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // Thread transfer LDS to Vmem + auto cde_shuffle_block_copy_lds_and_global = + Base::template GetLDSToVmemEpilogueDescriptor( + c_ds_desc_refs, + e_grid_desc_mblock_mperblock_nblock_nperblock, + cde_element_op, + block_m_id, + block_n_id); + + // Block descriptor + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + GetCShuffleLDSDescriptor(); + + // E Vgpr buffer + constexpr index_t PostShuffleThreadSliceSize_M = + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) / + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1); + + constexpr index_t PostShuffleThreadSliceSize_N = + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) / + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3); + + constexpr auto PostShuffleThreadSliceSize_M_N = + Sequence{}; + + // Welford + constexpr auto post_shuffle_thread_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, + Number{}, + Number<1>{}, + Number{})); + + auto e_thread_buf = make_static_buffer( + post_shuffle_thread_desc_m_n.GetElementSpaceSize()); + + using PostShuffleThreadClusterSize_M_N = Sequence< + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1), + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>; + + constexpr auto post_shuffle_thread_cluster_desc = + make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{}); + + const auto post_shuffle_thread_cluster_idx = + post_shuffle_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto post_shuffle_thread_data_idx_begin = + post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N; + + constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{})); + + constexpr auto thread_welford_dst_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadwiseWelford = ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford, + false>; + + constexpr int num_shuffleM = + MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma); + + constexpr int num_shuffleN = + NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma); + + using mean_var_vgpr_type = decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + using welford_count_vgpr_type = + decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + Array threadwise_welfords; + Array mean_thread_bufs; + Array var_thread_bufs; + Array welford_count_thread_bufs; + + int max_count = PostShuffleThreadSliceSize_N * num_shuffleN; + const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2); + + // tail block + if(block_n_id % nblock == nblock - 1) + { + constexpr index_t NPerShuffleBlock = + CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * 
NPerWmma; + + int NPerBlockTail = NRaw - NPerBlock * (nblock - 1); + int thread_max_len = + PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1); + int shuffle_step = 0; + while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN) + { + ++shuffle_step; + thread_max_len += NPerShuffleBlock; + } + + int delta = 0; + if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N) + delta = 0; + else if(NPerBlockTail > thread_max_len) + delta = PostShuffleThreadSliceSize_N; + else + delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail; + + max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta; + } + + // Initialize Welford + static_for<0, num_shuffleM, 1>{}([&](auto i) { + threadwise_welfords(i).max_count_ = max_count; + mean_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + var_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + welford_count_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + mean_thread_bufs(i)(j) = type_convert(0.0f); + var_thread_bufs(i)(j) = type_convert(0.0f); + welford_count_thread_bufs(i)(j) = 0; + }); + }); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + // Run CShuffle + Store E + Welford threadwise + int shuffleM_index = __builtin_amdgcn_readfirstlane(0); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread shuffles its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // Read LDS / Vmem + CDE elementwise operation + cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs); + + // Store to Vmem, but keep data in Vgpr for Welford + cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf), + tie(post_shuffle_thread_desc_m_n), + tie(e_thread_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + } + + // Threadwise welford + auto& threadwise_welford = threadwise_welfords(shuffleM_index); + auto& mean_thread_buf = mean_thread_bufs(shuffleM_index); + auto& var_thread_buf = var_thread_bufs(shuffleM_index); + + threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id); + constexpr int shuffleMInc = + de_global_step[I1] / + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I1); + shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc); + } + }); + + // Blockwise
welford and write out + static_for<0, num_shuffleM, 1>{}([&](auto i) { + auto& mean_thread_buf = mean_thread_bufs(i); + auto& var_thread_buf = var_thread_bufs(i); + auto& count_thread_buf = welford_count_thread_bufs(i); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + block_sync_lds(); + count_thread_buf(j) = threadwise_welfords(i).cur_count_; + BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j)); + }); + + if(post_shuffle_thread_cluster_idx[I1] == 0) + { + constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1)); + + constexpr int shuffleMPerBlock = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I1); + + auto mean_var_count_thread_copy_index = make_multi_index( + block_m_id, // mblock + shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock + block_n_id); // nblock + + auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EDataType, + decltype(thread_welford_desc_I_m_I), + decltype(mean_var_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{mean_var_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + mean_thread_buf, + mean_var_grid_desc_mblock_mperblock_nblock, + mean_grid_buf); // write mean + + mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + var_thread_buf, + mean_var_grid_desc_mblock_mperblock_nblock, + var_grid_buf); // write variance + + // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need + // to be written. + if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0) + { + auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + int32_t, + int32_t, + decltype(thread_welford_desc_I_m_I), + decltype(count_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + false>{count_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + count_thread_buf, + count_grid_desc_mblock_mperblock_nblock, + welford_count_grid_buf); // write count + } + } + }); + } + + EDataType* p_welford_mean_grid; + EDataType* p_welford_var_grid; + int32_t* p_welford_count_grid; + index_t NRaw; + GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock; + GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp new file mode 100644 index 0000000000..ccd999b724 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
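+
+// EpilogueCShuffle: the plain (non-fused) CShuffle epilogue. For each
+// space-filling-curve access it copies the accumulator tile from VGPR to LDS
+// (AccDataType -> CShuffleDataType), then a blockwise transfer reads the
+// shuffled tile from LDS together with the Ds tiles from global memory,
+// applies CDEElementwiseOperation, and stores the resulting E tile to global
+// memory. The shared machinery lives in EpilogueCShuffleBase; fused variants
+// such as EpilogueWelfordCShuffle reuse it and extend Run().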
+ +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp" + +namespace ck { + +template +struct EpilogueCShuffle + : EpilogueCShuffleBase +{ + using Base = EpilogueCShuffleBase< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe>; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + using Base::GetCShuffleLDSDescriptor; + using Base::GetVgprToLDSEpilogueDescriptor; + using Base::I1; + using Base::NumDTensor; + + template + __device__ static void Run(CThreadBuf& c_thread_buf, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // LDS buffer + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + // Thread transfer Vgpr to LDS + auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); + + // Space Filling Curve Vgpr + constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; + + // Space Filling Curve Vmem + constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; + + // Block descriptor + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + GetCShuffleLDSDescriptor(); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // Thread transfer LDS to Vmem + auto cde_shuffle_block_copy_lds_and_global = + Base::template GetLDSToVmemEpilogueDescriptor( + c_ds_desc_refs, + e_grid_desc_mblock_mperblock_nblock_nperblock, + cde_element_op, + block_m_id, + block_n_id); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + constexpr index_t num_access = 
sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + // CShuffle and Store + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block loads its C data from LDS, D from global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + } + }); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp new file mode 100644 index 0000000000..d2c6c92c9f --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
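+
+// EpilogueCShuffleBase: building blocks common to all CShuffle epilogues:
+// - the LDS tile descriptor for the shuffle
+//   (GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat)
+//   and its remapped per-wave/subgroup view (GetCShuffleLDSDescriptor),
+// - the per-thread VGPR -> LDS transfer (GetVgprToLDSEpilogueDescriptor),
+// - the blockwise LDS/Ds -> E transfer (GetLDSToVmemEpilogueDescriptor),
+// - the space-filling curves that iterate over the shuffle tiles.
+// Derived epilogues (EpilogueCShuffle, EpilogueWelfordCShuffle) compose these
+// in their Run() implementations.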
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +namespace ck { + +template +struct EpilogueCShuffleBase +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto EShuffleBlockTransferScalarPerVector = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + using SpaceFillingCurveVgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>; + + using SpaceFillingCurveVmem = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>; + + // *Caution Here repeat is shuffle repeat + __device__ static constexpr auto + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __device__ static constexpr auto GetCShuffleLDSDescriptor() + { + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + BlockwiseGemmPipe:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + return transform_tensor_descriptor( + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(), + make_tuple(make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + } + + __device__ static auto GetVgprToLDSEpilogueDescriptor() + { + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + BlockwiseGemmPipe:: + 
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + return ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()), + decltype(GetCShuffleLDSDescriptor()), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{GetCShuffleLDSDescriptor(), + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + } + + template + __device__ static auto + GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs, + EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); }, + Number{})); + + // blockwise copy which loads C from LDS, D from global, applies elementwise + // operation and stores result E to global + return ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + CDsDescRefs, + 
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOps, + Sequence<1, + CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * + NPerWmma>, // BlockSliceLengths, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // DstDimAccessOrder, + 3, // SrcVectorDim, + 3, // DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors + EShuffleBlockTransferScalarPerVector, // DstScalarPerVector + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + 1, + Tuple>{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + cde_element_op}; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 25653dd859..5d8bbca79d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -315,8 +315,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; - using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; - using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; @@ -556,7 +554,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 template + TailNumber TailNum, + typename EpilogueArgument> __device__ static void Run(AsGridPointer& p_as_grid, BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, @@ -565,7 +564,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args) { const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); @@ -610,6 +610,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), decltype(b_scale_struct), + decltype(epilogue_args), HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(p_as_grid, @@ -627,16 +628,20 @@ struct GridwiseGemm_wmma_cshuffle_v3 block_m_id, block_n_id, num_k_block_per_scale, - b_scale_struct); + b_scale_struct, + epilogue_args); } // Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. 
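 // A __global__ entry point is expected to size LDS for the chosen epilogue
 // and construct the epilogue argument object itself, along the lines of the
 // kernel change in gridwise_gemm_wmma_cshuffle_v3_common.hpp (a sketch):
 //
 //   constexpr index_t LDS_size = GridwiseGemm::template
 //       GetSharedMemoryNumberOfByte<typename GridwiseGemm::EpilogueCShuffle>();
 //   __shared__ char p_shared[LDS_size];
 //   auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
 //   GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
 //       p_shared, splitk_batch_offset, karg, epilogue_args);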
template - __device__ static void - Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) + TailNumber TailNum, + typename EpilogueArgument> + __device__ static void Run(void* p_shared, + const SplitKBatchOffset& splitk_batch_offset, + Argument& karg, + EpilogueArgument& epilogue_args) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -663,7 +668,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 karg, karg.a_element_op, karg.b_element_op, - karg.cde_element_op); + karg.cde_element_op, + epilogue_args); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 1b8a8ef09e..ca4646a1c1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -209,8 +209,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; - using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; - using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; @@ -533,7 +531,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale template + TailNumber TailNum, + typename EpilogueArgument> __device__ static void Run(AsGridPointer& p_as_grid, BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, @@ -543,7 +542,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args) { const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); @@ -593,6 +593,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), decltype(b_scale_struct), + decltype(epilogue_args), HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(p_as_grid, @@ -610,16 +611,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale block_m_id, block_n_id, num_k_block_per_scale, - b_scale_struct); + b_scale_struct, + epilogue_args); } // NOTE: Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. 
template - __device__ static void - Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) + TailNumber TailNum, + typename EpilogueArgument> + __device__ static void Run(void* p_shared, + const SplitKBatchOffset& splitk_batch_offset, + Argument& karg, + EpilogueArgument& epilogue_args) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -647,7 +652,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale karg, karg.a_element_op, karg.b_element_op, - karg.cde_element_op); + karg.cde_element_op, + epilogue_args); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 523cb8efd1..7a5e324468 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -23,6 +23,8 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" namespace ck { @@ -46,12 +48,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg); + p_shared, splitk_batch_offset, karg, epilogue_args); #if defined(__gfx11__) } @@ -262,9 +268,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static_assert(!PermuteA, "PermuteA is not supported"); // return block_id to C matrix tile idx (m0, n0) mapping - // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { @@ -539,23 +543,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base Number{}); } - __host__ __device__ static constexpr auto - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; - } - using BlockwiseGemmPipe = remove_cvref_t())>; + // Used to create obj in global function and pass it to Run method + using EpilogueCShuffle = + EpilogueCShuffle; + + using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + 
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe, + BlockSize>; + template __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) @@ -821,6 +848,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } + template __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment @@ -838,7 +866,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + EpilogueType:: + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); constexpr auto c_block_size = c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat @@ -867,6 +896,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename BScaleStruct, + typename EpilogueArgument, bool HasMainKBlockLoop, InMemoryDataOperationEnum EGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> @@ -887,7 +917,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t& block_m_id, const index_t& block_n_id, const index_t& num_k_block_per_scale, - BScaleStruct& b_scale_struct) + BScaleStruct& b_scale_struct, + EpilogueArgument& epilogue_args) { const auto as_grid_buf = generate_tuple( [&](auto i) { @@ -903,16 +934,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base }, Number{}); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -984,240 +1005,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base num_k_block_per_scale); // shuffle C and write out - { - // C mapping in single thread. 
- constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm_pipeline - .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // C mapping in single block - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = - blockwise_gemm_pipeline - .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I1); - constexpr auto MSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I2); - constexpr auto NWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I4); - constexpr auto NThreadPerSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I5); - constexpr auto MAccVgprs = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I6); - - // LDS descriptor, shuffle and write out in MRepeat x NRepeat times - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize()); - - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // MRepeat per shuffle repeat - MWave, // MWave - MSubGroup, // MSubGroup * MAccVgprs = MPerWmma - MAccVgprs)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // NRepeat per shuffle repeat - NWave, // NWave - NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 1, 2, 6>{}, - Sequence<>{}, - Sequence<3, 4, 5>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - MRepeat, MWave, MSubGroup, MAccVgprs))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor - .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - NRepeat, NWave, NThreadPerSubGroup))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor - 
.CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6>, - 6, - 1, // vector write pixel - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - make_multi_index(0, - m_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - 0, - n_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor buffers - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); }, - Number{})); - - // blockwise copy which loads C from LDS, D from global, applies elementwise - // operation and stores result E to global - auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, // ThreadGroup - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, // ElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // DstInMemOps, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // DstDimAccessOrder, - 3, // SrcVectorDim, - 3, // DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors - EShuffleBlockTransferScalarPerVector, // DstScalarPerVector - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - cde_element_op}; - - // space filling curve for local reg & global memory - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_cde_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - 
CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block loads its C data from LDS, D from global, applies elementwise - // operation and stores result E to global - cde_shuffle_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_global_step); - }); - - // move on E - cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); - } - }); - } + epilogue_args.template Run( + c_thread_buf, + p_ds_grid, + p_e_grid, + p_shared, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + cde_element_op, + block_m_id, + block_n_id); } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index 5682117f76..7e9870bf91 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -43,7 +43,8 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence - index_t NumThreadScratch = 1> + index_t NumThreadScratch = 1, + typename InterDatas = DstDatas> struct ThreadwiseTensorSliceTransfer_v7r3 { static constexpr auto I0 = Number<0>{}; @@ -153,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3 // loop over space-filling curve static_for<0, src_num_access, 1>{}([&](auto iAccess) { auto src_vectors = generate_vectors(); - auto elm_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); bool oob_val = true; @@ -226,9 +227,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3 auto dst_data_refs = generate_tie( // return type should be lvalue [&](auto iDst) -> auto& { - using DstData = remove_cvref_t>; + using InterData = remove_cvref_t>; - using elem_op_vec_t = typename vector_type::type; + using elem_op_vec_t = + typename vector_type::type; return elm_vectors(iDst).template AsType()(i); }, @@ -297,17 +299,17 @@ struct ThreadwiseTensorSliceTransfer_v7r3 __device__ void TransposeFromElmToDst(Number thread_scratch_id = Number{}) { - using DstData = remove_cvref_t; + using InterData = remove_cvref_t; using ElmThreadScratch = StaticTensorTupleOfVectorBuffer; using DstThreadScratch = StaticTensorTupleOfVectorBuffer; @@ -319,11 +321,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3 bit_cast(elm_vectors_tuple_[thread_scratch_id]); if constexpr(SrcVectorDim != DstVectorDim && - 
((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || - (is_same>::value && + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || - (is_same>::value && + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
 { // each transpose does
@@ -356,8 +358,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{});
- using src_vector_t = vector_type_maker_t;
- using dst_vector_t = vector_type_maker_t;
+ using src_vector_t = vector_type_maker_t;
+ using dst_vector_t = vector_type_maker_t;
 // get DstScalarPerVector # of read-only references to src vectors from
 // src_thread_scratch_
@@ -380,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3 Number{}); // do data transpose
- transpose_vectors{}(
+ transpose_vectors{}(
 src_vector_refs, dst_vector_refs); }); }
@@ -393,6 +395,104 @@ struct ThreadwiseTensorSliceTransfer_v7r3 dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); }
+ // DstDescs: Tuple
+ // DstBuffers: Tuple
+ // DstVgprDescs: Tuple
+ // DstVgprBuffers: Tuple
+ template = false>
+ __device__ void
+ RunWriteAndStoreVgpr(const DstDescs& dst_descs,
+ DstBuffers dst_bufs,
+ const DstVgprDescs&,
+ DstVgprBuffers dst_vgpr_buf,
+ Number thread_scratch_id = Number{})
+ {
+ // Same functionality as RunWrite, but additionally stores the internal Vgpr data in dst_vgpr_buf
+ OOBCheck(thread_scratch_id);
+ TransposeFromElmToDst(thread_scratch_id);
+
+ // Vgpr buffer origin is set internally to 0
+ constexpr auto dst_slice_origin_idx =
+ generate_tuple([&](auto) { return I0; }, Number{});
+ constexpr auto dst_scalar_step_in_vector =
+ generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{});
+
+ // loop over space-filling curve
+ static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
+ auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
+
+ static_for<0, nDst, 1>{}([&](auto i) {
+ // copy data from dst_vectors into dst_bufs
+ using DstData = remove_cvref_t;
+ using InterData = remove_cvref_t;
+
+ typename vector_type_maker::type dst_vector;
+ using dst_vector_t =
+ typename vector_type_maker::type::type;
+
+ static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
+ dst_vector.template AsType()(j) =
+ type_convert(dst_vectors[i].template AsType()[j]);
+ });
+
+ const bool is_dst_valid =
+ coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
+ dst_coords_[i]);
+
+ constexpr InMemoryDataOperationEnum DstInMemOp =
+ static_cast(DstInMemOps::At(i.value));
+
+ dst_bufs(i).template Update(
+ dst_coords_[i].GetOffset(),
+ is_dst_valid,
+ dst_vector.template AsType()[I0]);
+
+ // store Vgpr
+ using DstVgprDesc = remove_cvref_t;
+ static_assert(DstVgprDesc::IsKnownAtCompileTime(),
+ "wrong! DstVgprDesc needs to be known at compile-time");
+ constexpr auto dst_vgpr_desc = DstVgprDesc{};
+
+ constexpr auto src_data_idx = DstSpaceFillingCurve::GetIndex(iAccess);
+ static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
+ constexpr index_t dst_offset =
+ dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
+ src_data_idx + j * dst_scalar_step_in_vector);
+
+ dst_vgpr_buf(I0)(Number{}) =
+ is_dst_valid ?
dst_vectors[i].template AsType()[j] + : NumericLimits::QuietNaN(); + }); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + // DstDescs: Tuple // DstBuffers: Tuple template thread_scratch_id = Number{}) { + static_assert(is_same_v, + "RunWrite doesn't support inter data type different from dst data type"); + OOBCheck(thread_scratch_id); TransposeFromElmToDst(thread_scratch_id); @@ -630,8 +733,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3 private: using SrcVectorsType = decltype(generate_vectors()); - using ElmVectorsType = decltype(generate_vectors()); - using DstVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp old mode 100755 new mode 100644 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp index dd8ecae62c..b252186b10 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
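 
 // The instances registered below fuse GEMM, two elementwise adds with a relu,
 // and a row-wise layernorm. As a scalar reference (a sketch; eps denotes the
 // layernorm epsilon and acc the f32 GEMM accumulator):
 //
 //   e[m, n] = relu(acc[m, n] + d0[m, n]) + d1[m, n]                  // AddReluAdd
 //   mean[m] = (1/N) * sum_n e[m, n]
 //   var[m]  = (1/N) * sum_n (e[m, n] - mean[m])^2
 //   h[m, n] = gamma[n] * (e[m, n] - mean[m]) / sqrt(var[m] + eps) + beta[n]
 //
 // The WMMA path obtains mean/var through the Welford epilogue
 // (epilogue_cshuffle_v3_welford_wmma.hpp); Welford's standard one-pass update
 // per element x is
 //
 //   count += 1; delta = x - mean; mean += delta / count; m2 += delta * (x - mean);
 //
 // with var = m2 / count.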
#pragma once @@ -15,6 +15,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( std::vector>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( + std::vector>>&); +#endif // GEMM + Add + Relu + Add + Layernorm template && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( + op_ptrs); +#endif } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index b9aeb6a6db..87b414faf7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,7 +1,12 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp + + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..2f0f28b5e1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// output: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | | + //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, 
F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, 
GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| 
CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | | + //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances< + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); + add_device_operation_instances( + instances, + device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances< + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..84c4dac078 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
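+
+// "km_nk" in the file name encodes the operand layouts: A is consumed as
+// a[k, m] (column-major) and B as b[n, k] (column-major), while D0/D1 and the
+// output H are row-major m x n tensors, hence the <Col, Col, Row_Row_Tuple, Row>
+// layout arguments of every instance below. With B in nk form its K dimension
+// is contiguous, which is why the BBlockTransfer columns of these instances use
+// access order S<1, 0, 2> and SrcVectorDim 2 (vectorized along K1), unlike the
+// kn variants, which vectorize along N.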
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_F16_Tuple = ck::Tuple<F16, F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+using Row_Row_Tuple = ck::Tuple<Row, Row>;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// e = elementwise((a * b), d0, d1)
+// h = layernorm(e, gamma, beta)
+// output: h[m, n]
+// input: a[k, m], b[n, k], d0[m, n], d1[m, n], gamma[n], beta[n]
+template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
+using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple<
+    // clang-format off
+    //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
+    //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
+    //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
+    //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
+        DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
+        DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd,
PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, 
Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
+        //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
+        //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
+        DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
+        // clang-format on
+        >;
+
+void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
+    std::vector>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+    add_device_operation_instances(
+        instances,
+        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
new file mode 100644
index 0000000000..1153df40d4
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
@@ -0,0 +1,108 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// output: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | | + //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, 
PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, 
Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
+        //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
+        //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
+        DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
+        // clang-format on
+        >;
+
+void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
+    std::vector>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+    add_device_operation_instances(
+        instances,
+        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
new file mode 100644
index 0000000000..29972a9010
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
@@ -0,0 +1,105 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// output: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | | + //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, 
PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, 
F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +template +using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline| + //##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | | + //##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + 
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+
+    add_device_operation_instances(
+        instances,
+        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances<
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
index a8daf4e787..99076a20ec 100644
--- a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
+++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -167,6 +167,12 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
     Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
     Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
 
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "h_m_n: " << h_m_n.mDesc << std::endl;
+
     switch(init_method)
     {
     case 0: break;
@@ -312,9 +318,8 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
 
             float gb_per_sec = num_byte / 1.E6 / ave_time;
 
-            if(time_kernel)
-                std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
-                          << " GB/s, " << op_name << std::endl;
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
+                      << op_name << std::endl;
 
             if(ave_time < best_ave_time)
             {
@@ -333,8 +338,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
         }
         else
         {
-            if(time_kernel)
-                std::cout << op_name << " does not support this problem" << std::endl;
+            std::cout << op_name << " does not support this problem" << std::endl;
         }
     }
 
diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt
index d1102a561a..d912ce301c 100644
--- a/test/gemm_layernorm/CMakeLists.txt
+++ b/test/gemm_layernorm/CMakeLists.txt
@@ -1,6 +1,8 @@
-add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp)
-if(result EQUAL 0)
-    add_custom_target(test_gemm_layernorm)
-    target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
-    add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
+if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
+    add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
+    if(result EQUAL 0)
+        add_custom_target(test_gemm_layernorm)
+        target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
+        add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
+    endif()
 endif()
diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
similarity index 96%
rename from test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp
rename to test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
index ae872d3133..93142155d5 100644
--- a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp
+++ b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
@@ -79,11 +79,6 @@ TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);
 TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); }
 int main(int argc, char** argv)
 {
-    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
-    {
-        std::cout << "No available instance for gfx11 & gfx12." << std::endl;
-        return 0;
-    }
     testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
 }
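
For reference, the fused computation registered by the new instance files can be sketched on the host as follows. This is a minimal float sketch, not part of the patch: it assumes row-major a[m, k] and b[k, n] as in the mk_kn layout variant, interprets AddReluAdd as e = relu(c + d0) + d1 per the "e = elementwise((a * b), d0, d1)" comment in the instance files, and uses an illustrative function name and epsilon; the device instances themselves compute in F16 with F32 accumulation.

// Illustrative reference: h = layernorm(relu(a * b + d0) + d1, gamma, beta)
#include <cmath>
#include <vector>

void reference_gemm_add_relu_add_layernorm(const std::vector<float>& a,     // a[m, k], row-major
                                           const std::vector<float>& b,     // b[k, n], row-major
                                           const std::vector<float>& d0,    // d0[m, n]
                                           const std::vector<float>& d1,    // d1[m, n]
                                           const std::vector<float>& gamma, // gamma[n]
                                           const std::vector<float>& beta,  // beta[n]
                                           std::vector<float>& h,           // h[m, n]
                                           int M, int N, int K, float epsilon = 1e-4f)
{
    std::vector<float> e(static_cast<size_t>(M) * N);
    for(int m = 0; m < M; ++m)
    {
        // GEMM plus AddReluAdd epilogue: e = relu(a * b + d0) + d1 (assumed semantics)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            const float t = acc + d0[m * N + n];
            e[m * N + n]  = (t > 0.f ? t : 0.f) + d1[m * N + n];
        }
        // Layernorm over the N dimension of each row:
        // h = gamma * (e - mean) / sqrt(var + epsilon) + beta
        float mean = 0.f;
        for(int n = 0; n < N; ++n)
            mean += e[m * N + n];
        mean /= N;
        float var = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float diff = e[m * N + n] - mean;
            var += diff * diff;
        }
        var /= N;
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int n = 0; n < N; ++n)
            h[m * N + n] = gamma[n] * (e[m * N + n] - mean) * inv_std + beta[n];
    }
}

The naive two-pass mean/variance above is for clarity only and is suitable just for validating small problem sizes against the device output.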