WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor the epilogue (with CShuffle) to support fused operations (sketched below):
  - EpilogueCShuffleBase holds the common parts
  - EpilogueCShuffle runs CShuffle and the write-out
  - EpilogueWelfordCShuffle holds the Welford-specific arguments; it runs CShuffle, the write-out, the Welford first part, and the Welford write-out
- Extend thread transfer v7r3:
  - Support an intermediate data type that differs from the src and dst types
  - New functionality to write to the dst buffer while keeping the data in registers, so it can be reused for additional operations
* Address review comments
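For orientation, a minimal sketch of the class split described in the summary. Only the constructor arguments of the Welford epilogue are taken from the diff below; everything else here is an abbreviated placeholder (float stands in for the E/mean/var data type), not the actual implementation:

    #include <cstdint>

    // Hypothetical, heavily abbreviated shapes of the three classes; the real
    // templates carry dozens of parameters (see the new files in this diff).
    struct EpilogueCShuffleBase
    {
        // common parts: CShuffle LDS descriptors, VGPR->LDS and LDS->global copies
    };

    struct EpilogueCShuffle : EpilogueCShuffleBase
    {
        // stateless: runs CShuffle and writes E out to global memory
    };

    struct EpilogueWelfordCShuffle : EpilogueCShuffleBase
    {
        // Welford-specific arguments, mirroring the constructor in the diff:
        float*   p_welford_mean_grid;  // partial means, one column per N-block
        float*   p_welford_var_grid;   // partial variances
        int32_t* p_welford_count_grid; // partial counts
        int32_t  NRaw;                 // unpadded N, needed for the tail tile
        // Run(): CShuffle, write E out, Welford first part, write mean/var/count out
    };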
@@ -292,13 +292,15 @@ struct BlockwiseGemmWmmaops_pipeline_base
             make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
     }
 
+    static constexpr auto MAccVgprs =
+        wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];
+
     __host__ __device__ static constexpr auto
     GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
     {
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
-        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
         constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
         return make_naive_tensor_descriptor(
             // |MRepeat          |MWave           |MSubGroup          |NRepeat          |NWave
@@ -42,7 +42,8 @@ template <typename ThreadGroup,
           index_t DstScalarPerVector,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
           typename ThreadTransferDstResetCoordinateAfterRunFlags,
-          index_t NumThreadScratch = 1>
+          index_t NumThreadScratch = 1,
+          typename InterDatas      = DstDatas>
 struct ThreadGroupTensorSliceTransfer_v7r3
 {
     static constexpr index_t nDim =
@@ -97,7 +98,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                       "wrong! ThreadGroup::GetNumOfThread() too small");
 
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
@@ -123,7 +124,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
                         const SrcBuffers& src_bufs,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
@@ -138,7 +139,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
                         DstBuffers dst_bufs,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
@@ -148,6 +149,36 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         }
     }
 
+    template <typename DstBuffers,
+              typename DstVgprDescs,
+              typename DstVgprBuffers,
+              index_t ThreadScratchId = 0>
+    __device__ void
+    RunWriteAndStoreVgpr(const DstDescs& dst_descs,
+                         DstBuffers dst_bufs,
+                         const DstVgprDescs& dst_vgpr_desc,
+                         DstVgprBuffers dst_vgpr_buf,
+                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
+    {
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
+        {
+            if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value &&
+                         is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
+                threadwise_transfer_.RunWriteAndStoreVgpr(
+                    dst_descs, dst_bufs, dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
+            else if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
+                threadwise_transfer_.RunWriteAndStoreVgpr(
+                    dst_descs, dst_bufs, dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
+            else if constexpr(is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
+                threadwise_transfer_.RunWriteAndStoreVgpr(
+                    dst_descs, tie(dst_bufs), dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
+            else
+                threadwise_transfer_.RunWriteAndStoreVgpr(
+                    dst_descs, tie(dst_bufs), dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
+        }
+    }
+
     template <typename SrcBuffers, typename DstBuffers>
     __device__ void Run(const SrcDescs& src_descs,
                         const SrcBuffers& src_bufs,
@@ -162,7 +193,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
     __device__ void
     MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
@@ -179,7 +210,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
     __device__ void
     MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
@@ -212,7 +243,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         DstScalarPerVector,
         ThreadTransferSrcResetCoordinateAfterRunFlags,
         ThreadTransferDstResetCoordinateAfterRunFlags,
-        NumThreadScratch>;
+        NumThreadScratch,
+        InterDatas>;
 
     ThreadwiseTransfer threadwise_transfer_;
 };
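The four `if constexpr` branches of RunWriteAndStoreVgpr above only normalize the arguments: anything that is not already a tuple is wrapped with tie so a single implementation handles every combination. A self-contained sketch of that dispatch idiom, using std::tuple in place of ck's Tuple/tie machinery (illustration only, not the library's code):

    #include <tuple>
    #include <type_traits>

    template <typename T>
    struct is_std_tuple : std::false_type {};
    template <typename... Ts>
    struct is_std_tuple<std::tuple<Ts...>> : std::true_type {};

    // Single implementation: always receives tuples of buffer references.
    template <typename BufsTuple, typename VgprTuple>
    void write_and_keep_impl(BufsTuple& dst_bufs, VgprTuple& vgpr_bufs)
    {
        // would write each dst buffer and also keep a copy in the register-side
        // buffers so a later operation (here: Welford) can reuse the data
        (void)dst_bufs;
        (void)vgpr_bufs;
    }

    // Dispatch mirroring RunWriteAndStoreVgpr: wrap non-tuple arguments with
    // std::tie so all four argument combinations funnel into one code path.
    template <typename Bufs, typename VgprBufs>
    void write_and_keep(Bufs& bufs, VgprBufs& vgpr)
    {
        if constexpr(is_std_tuple<Bufs>::value && is_std_tuple<VgprBufs>::value)
            write_and_keep_impl(bufs, vgpr);
        else if constexpr(is_std_tuple<Bufs>::value)
        {
            auto v = std::tie(vgpr);
            write_and_keep_impl(bufs, v);
        }
        else if constexpr(is_std_tuple<VgprBufs>::value)
        {
            auto b = std::tie(bufs);
            write_and_keep_impl(b, vgpr);
        }
        else
        {
            auto b = std::tie(bufs);
            auto v = std::tie(vgpr);
            write_and_keep_impl(b, v);
        }
    }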

@@ -60,7 +60,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     const long_index_t c_batch_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
 
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
+        typename GridwiseGemm::EpilogueCShuffle>();
+    __shared__ char p_shared[LDS_size];
 
     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
@@ -82,6 +84,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
         });
 
+    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         p_as_grid_shift,
         p_bs_grid_shift,
@@ -91,7 +95,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         karg,
         karg.a_element_op,
         karg.b_element_op,
-        karg.cde_element_op);
+        karg.cde_element_op,
+        epilogue_args);
 #if defined(__gfx11__)
     }
 #endif
@@ -46,12 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                    std::is_same_v<c_data_type, ck::bhalf_t>)))
     {
 #endif
+    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
+        typename GridwiseGemm::EpilogueCShuffle>();
     // The normal approach to batching would be to increase the grid size by just stretching out
     // the grid Z dimension (which is the outermost dimension), but this depends on lower level
     // functions not directly using the Z dimension for other calculations. As it turns out, k
     // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
     // we will use the grid Y dimension for batching. This may be a bit fragile.
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[LDS_size];
 
     const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
 
@@ -84,6 +86,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
         });
 
+    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         p_as_grid_shift,
         p_bs_grid_shift,
@@ -94,7 +98,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
         karg,
        karg.a_element_op,
         karg.b_element_op,
-        karg.cde_element_op);
+        karg.cde_element_op,
+        epilogue_args);
 #if defined(__gfx11__)
     }
 #endif

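Both kernels above now query the LDS requirement through a template parameter instead of a plain function, presumably because different epilogues need different amounts of shared memory. A toy model of that pattern (all names and sizes here are made up for illustration; this is not the library's implementation):

    #include <cstddef>

    struct EpilogueCShuffle        { static constexpr std::size_t extra_lds_bytes = 0;   };
    struct EpilogueWelfordCShuffle { static constexpr std::size_t extra_lds_bytes = 256; };

    // Hypothetical stand-in for GridwiseGemm::GetSharedMemoryNumberOfByte<Epilogue>():
    // the GEMM tile footprint plus whatever the chosen epilogue adds on top.
    template <typename Epilogue>
    constexpr std::size_t get_shared_memory_number_of_byte()
    {
        constexpr std::size_t gemm_lds_bytes = 16384; // made-up tile footprint
        return gemm_lds_bytes + Epilogue::extra_lds_bytes;
    }

    static_assert(get_shared_memory_number_of_byte<EpilogueWelfordCShuffle>() >=
                  get_shared_memory_number_of_byte<EpilogueCShuffle>());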
@@ -0,0 +1,896 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename EMeanVarDataType,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3(
            typename GridwiseGemm::Argument karg,
            EMeanVarDataType* __restrict__ p_welford_mean_grid,
            EMeanVarDataType* __restrict__ p_welford_var_grid,
            int32_t* __restrict__ p_welford_count_grid)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
    using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
                   (std::is_same_v<e_data_type, ck::half_t> ||
                    std::is_same_v<e_data_type, ck::bhalf_t>)))
    {
#endif
        constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
            typename GridwiseGemm::EpilogueWelfordCShuffle>();

        __shared__ char p_shared[LDS_size];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle(
            p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N);

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            p_shared, splitk_batch_offset, karg, epilogue_args);

#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
    ignore = p_welford_mean_grid;
    ignore = p_welford_var_grid;
    ignore = p_welford_count_grid;
#endif
}

template <typename GridwiseWelfordLayernorm,
          typename EMeanVarDataType,
          typename HDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename EHGridDesc_M_N,
          typename LayernormMeanVarGridDesc_M_NBlock,
          typename LayernormCountGridDesc_M_NBlock,
          typename GammaBetaGridDesc_N,
          typename HElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_welford_layernorm2d_second_half(
            const EMeanVarDataType* __restrict__ p_e_grid,
            const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
            const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
            const int32_t* __restrict__ p_in_welford_count_grid,
            const GammaDataType* __restrict__ p_gamma_grid,
            const BetaDataType* __restrict__ p_beta_grid,
            HDataType* __restrict__ p_h_grid,
            const EHGridDesc_M_N e_grid_desc_m_n,
            const EHGridDesc_M_N h_grid_desc_m_n,
            const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
            const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
            const GammaBetaGridDesc_N gamma_grid_desc_n,
            const GammaBetaGridDesc_N beta_grid_desc_n,
            index_t numMeanVarCountBlockTileIteration_N,
            index_t NBlockClusterLength,
            ComputeDataType epsilon,
            HElementwiseOperation h_element_op)
{
    GridwiseWelfordLayernorm::Run(p_e_grid,
                                  p_in_welford_mean_grid,
                                  p_in_welford_var_grid,
                                  p_in_welford_count_grid,
                                  p_gamma_grid,
                                  p_beta_grid,
                                  p_h_grid,
                                  e_grid_desc_m_n,
                                  h_grid_desc_m_n,
                                  mean_var_grid_desc_m_nblock,
                                  count_grid_desc_m_nblock,
                                  gamma_grid_desc_n,
                                  beta_grid_desc_n,
                                  numMeanVarCountBlockTileIteration_N,
                                  NBlockClusterLength,
                                  epsilon,
                                  h_element_op);
}

} // namespace ck
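Together these two kernels implement a split Welford reduction: the first emits one partial (mean, variance, count) triple per N-block of the GEMM output, and the second merges the partials along N before applying the layernorm. For reference, a scalar sketch of the standard parallel-combine step (Chan et al.) that such a merge performs; this is an illustration only, not code from this commit:

    struct WelfordPartial
    {
        double mean  = 0.0;
        double m2    = 0.0; // sum of squared deviations from the mean
        int    count = 0;
    };

    // Merge two partial Welford results computed over disjoint element sets.
    WelfordPartial combine(const WelfordPartial& a, const WelfordPartial& b)
    {
        WelfordPartial out;
        out.count = a.count + b.count;
        if(out.count == 0)
            return out;
        const double delta = b.mean - a.mean;
        out.mean = a.mean + delta * b.count / out.count;
        out.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
        return out; // population variance = out.m2 / out.count
    }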

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename HLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename HDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename EMeanVarDataType, // LayerNorm
          typename GammaDataType,    // LayerNorm
          typename BetaDataType,     // LayerNorm
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename HElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector,
          typename LayernormThreadClusterSize_M_N,
          index_t LayernormThreadSliceSize_M,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA                       = HDataType,
          typename ComputeTypeB                       = ComputeTypeA,
          bool PermuteA                               = false,
          bool PermuteB                               = false>
struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
    : public DeviceGemmMultipleDLayernorm<ALayout,
                                          BLayout,
                                          DsLayout,
                                          HLayout,
                                          ADataType,
                                          BDataType,
                                          DsDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          HDataType,
                                          AElementwiseOperation,
                                          BElementwiseOperation,
                                          CDEElementwiseOperation,
                                          HElementwiseOperation>
{
    // EDataType, MeanDataType and VarDataType must be the same.
    using DeviceOp = DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3;

    static constexpr index_t NumDTensor                  = DsDataType::Size();
    static constexpr index_t LayernormHDstVectorSize     = CDEShuffleBlockTransferScalarPerVector;
    static constexpr index_t LayernormGammaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
    static constexpr index_t LayernormBetaSrcVectorSize  = CDEShuffleBlockTransferScalarPerVector;
    static constexpr index_t LayernormESrcVectorSize     = CDEShuffleBlockTransferScalarPerVector;
    static constexpr index_t LayernormThreadSliceSize_N  = CDEShuffleBlockTransferScalarPerVector;

    using LayernormBlockTileSize_M_N =
        Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
                 LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_N>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using CDEShuffleBlockTransferScalarPerVectors =
        Sequence<CDEShuffleBlockTransferScalarPerVector,
                 CDEShuffleBlockTransferScalarPerVector,
                 CDEShuffleBlockTransferScalarPerVector>;

    // GEMM + Welford 1st part kernel
    using GridwiseGemmWelford = GridwiseGemm_wmma_cshuffle_v3<
        ALayout,
        BLayout,
        DsLayout,
        HLayout,
        Tuple<ADataType>,
        Tuple<BDataType>,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EMeanVarDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
        PermuteB>;

    // Welford 2nd part kernel
    template <typename DoPads, index_t MPerTile, index_t NPerTile>
    static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
    {
        // Only support row major for E and H
        const auto grid_desc_m_n =
            make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1));
        return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
    }

    template <index_t XPerTile>
    static auto MakeDescriptor_X(index_t X)
    {
        const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
        return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
    }

    using LayernormMeanVarGridDesc_M_NBlock =
        decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
                 Sequence<true, true>,
                 LayernormBlockTileSize_M_N::At(0),
                 LayernormBlockTileSize_M_N::At(1)>(1, 1));

    using LayernormCountGridDesc_M_NBlock =
        decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
                 Sequence<true, true>,
                 LayernormBlockTileSize_M_N::At(0),
                 LayernormBlockTileSize_M_N::At(1)>(1, 1));

    using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
    using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<Sequence<true, true>, 1, 1>(1, 1, 1));

    using GridwiseWelfordLayernorm =
        GridwiseWelfordSecondHalfLayernorm2d<EMeanVarDataType,
                                             HDataType,
                                             GammaDataType,
                                             BetaDataType,
                                             AccDataType,
                                             EHGridDesc_M_N,
                                             LayernormMeanVarGridDesc_M_NBlock,
                                             LayernormCountGridDesc_M_NBlock,
                                             GammaBetaGridDesc_N,
                                             HElementwiseOperation,
                                             BlockSize,
                                             LayernormThreadClusterSize_M_N::At(I0),
                                             LayernormThreadClusterSize_M_N::At(I1),
                                             LayernormThreadSliceSize_M,
                                             LayernormThreadSliceSize_N,
                                             LayernormESrcVectorSize,
                                             LayernormHDstVectorSize,
                                             LayernormGammaSrcVectorSize,
                                             LayernormBetaSrcVectorSize>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 const void* p_gamma_grid,
                 const void* p_beta_grid,
                 void* p_h_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t StrideA,
                 index_t StrideB,
                 std::array<index_t, NumDTensor> StrideDs,
                 index_t StrideH,
                 double epsilon,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 HElementwiseOperation h_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_workspace_e_grid_{nullptr},
              p_workspace_mean_{nullptr},
              p_workspace_var_{nullptr},
              p_workspace_count_{nullptr},
              p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
              p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
              p_h_grid_{static_cast<HDataType*>(p_h_grid)},
              layernorm_e_grid_desc_m_n_{
                  DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
                                                     LayernormBlockTileSize_M_N::At(0),
                                                     LayernormBlockTileSize_M_N::At(1)>(
                      MRaw, NRaw, StrideH)},
              layernorm_mean_var_grid_desc_m_nblock_{},
              layernorm_count_grid_desc_m_nblock_{},
              gamma_grid_desc_n_{
                  DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
              beta_grid_desc_n_{
                  DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
              h_grid_desc_m_n_{
                  DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>,
                                                     LayernormBlockTileSize_M_N::At(0),
                                                     LayernormBlockTileSize_M_N::At(1)>(
                      MRaw, NRaw, StrideH)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              h_element_op_{h_element_op},
              MRaw_{MRaw},
              NRaw_{NRaw},
              KRaw_{KRaw},
              StrideA_{StrideA},
              StrideB_{StrideB},
              StrideDs_{StrideDs},
              StrideH_{StrideH},
              gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
              epsilon_{static_cast<AccDataType>(epsilon)}
        {
            static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; });

            layernorm_mean_var_grid_desc_m_nblock_ =
                GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
                    Sequence<true, true>,
                    LayernormBlockTileSize_M_N::At(0),
                    LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);

            layernorm_count_grid_desc_m_nblock_ =
                GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
                    Sequence<true, true>,
                    LayernormBlockTileSize_M_N::At(0),
                    LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
        }

        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        std::array<const void*, NumDTensor> p_ds_grid_;
        void* p_workspace_e_grid_;
        void* p_workspace_mean_;
        void* p_workspace_var_;
        void* p_workspace_count_;
        const GammaDataType* p_gamma_grid_;
        const BetaDataType* p_beta_grid_;
        HDataType* p_h_grid_;

        // tensor descriptors (Welford second half)
        EHGridDesc_M_N layernorm_e_grid_desc_m_n_;
        LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
        LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_;
        GammaBetaGridDesc_N gamma_grid_desc_n_;
        GammaBetaGridDesc_N beta_grid_desc_n_;
        EHGridDesc_M_N h_grid_desc_m_n_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
        HElementwiseOperation h_element_op_;

        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
        index_t StrideA_;
        index_t StrideB_;
        std::array<index_t, NumDTensor> StrideDs_;
        index_t StrideH_;
        index_t gemm_nblock_;
        AccDataType epsilon_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            typename GridwiseGemmWelford::Argument gemm_arg{
                std::array<const void*, 1>{arg.p_a_grid_},
                std::array<const void*, 1>{arg.p_b_grid_},
                arg.p_ds_grid_,
                static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
                arg.MRaw_,
                arg.NRaw_,
                arg.KRaw_,
                std::array<index_t, 1>{arg.StrideA_}, // StrideAs
                std::array<index_t, 1>{arg.StrideB_}, // StrideBs
                arg.StrideDs_,                        // StrideDs
                arg.StrideH_,                         // StrideE
                I1,                                   // kbatch
                arg.a_element_op_,
                arg.b_element_op_,
                arg.cde_element_op_};

            if(stream_config.log_level_ > 0)
            {
                gemm_arg.Print();
                GridwiseGemmWelford::BlockwiseGemmPipe::HotLoopInstList::Print();
            }

            if(!GridwiseGemmWelford::CheckValidity(gemm_arg))
            {
                throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
            }

            if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
               arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
                throw std::runtime_error("wrong! WorkSpace pointer has not been set");

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) =
                GridwiseGemmWelford::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1);

            float ave_time = 0;

            index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock;

            const bool has_main_k_block_loop =
                GridwiseGemmWelford::CalculateHasMainKBlockLoop(K_split);

            const auto Run = [&](const auto& kernel_gemm_welford_first_half) {
                // Note: cache flushing not supported

                const auto kernel_welford_second_half =
                    kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
                                                           EMeanVarDataType,
                                                           HDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           AccDataType,
                                                           EHGridDesc_M_N,
                                                           LayernormMeanVarGridDesc_M_NBlock,
                                                           LayernormCountGridDesc_M_NBlock,
                                                           GammaBetaGridDesc_N,
                                                           HElementwiseOperation>;

                // First kernel launch: GEMM + Welford first part
                ave_time +=
                    launch_and_time_kernel(stream_config,
                                           kernel_gemm_welford_first_half,
                                           dim3(gdx, gdy, gdz),
                                           dim3(BlockSize),
                                           0,
                                           gemm_arg,
                                           static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
                                           static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
                                           static_cast<int32_t*>(arg.p_workspace_count_));

                // Second kernel launch: Welford second part
                const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
                const auto N = arg.h_grid_desc_m_n_.GetLength(I1);

                index_t MBlockClusterLength =
                    math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
                index_t NBlockClusterLength =
                    math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));

                auto grid_size = MBlockClusterLength * NBlockClusterLength;

                index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
                    arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));

                ave_time += launch_and_time_kernel(
                    stream_config,
                    kernel_welford_second_half,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
                    static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
                    static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
                    static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
                    static_cast<const int32_t*>(arg.p_workspace_count_),
                    arg.p_gamma_grid_,
                    arg.p_beta_grid_,
                    arg.p_h_grid_,
                    arg.layernorm_e_grid_desc_m_n_,
                    arg.h_grid_desc_m_n_,
                    arg.layernorm_mean_var_grid_desc_m_nblock_,
                    arg.layernorm_count_grid_desc_m_nblock_,
                    arg.gamma_grid_desc_n_,
                    arg.beta_grid_desc_n_,
                    numMeanVarCountBlockTileIteration_N,
                    NBlockClusterLength,
                    arg.epsilon_,
                    arg.h_element_op_);
            };

            constexpr index_t minimum_occupancy = []() {
                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
                {
                    return 2;
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
                }
                else
                {
                    return 1;
                }
            }();

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3<
                        GridwiseGemmWelford,
                        EMeanVarDataType,
                        true,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3<
                        GridwiseGemmWelford,
                        EMeanVarDataType,
                        false,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        size_t workspace_size = 0;

        int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;

        // workspace for welford intermediate mean
        workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;

        // workspace for welford intermediate variance
        workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;

        // workspace for welford intermediate count
        workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 128;

        if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
            workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);

        return (workspace_size);
    };

    void SetWorkSpacePointer(BaseArgument* pArg,
                             void* p_workspace,
                             const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;

        // setup buffer used for intermediate welford mean
        pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);

        index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
        mean_space_sz         = math::integer_least_multiple(mean_space_sz, 128);

        // setup buffer used for intermediate welford variance
        pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;

        index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
        variance_space_sz         = math::integer_least_multiple(variance_space_sz, 128);

        // setup buffer used for intermediate welford count
        pArg_->p_workspace_count_ =
            reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;

        index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
        count_space_sz         = math::integer_least_multiple(count_space_sz, 128);

        if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
            pArg_->p_workspace_e_grid_ =
                reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
        else
            pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
    };
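To make the pointer carving in SetWorkSpacePointer concrete, a worked example with hypothetical sizes (MRaw_ = 1000, gemm_nblock_ = 4, and a 2-byte EMeanVarDataType such as fp16); round_up plays the role of math::integer_least_multiple:

    #include <cstddef>

    // Round n up to the next multiple of align (128-byte alignment above).
    constexpr std::size_t round_up(std::size_t n, std::size_t align)
    {
        return (n + align - 1) / align * align;
    }

    constexpr std::size_t gemm_welford_size = 1000 * 4; // MRaw_ * gemm_nblock_
    constexpr std::size_t mean_bytes  = round_up(gemm_welford_size * 2, 128); // 8064
    constexpr std::size_t var_bytes   = round_up(gemm_welford_size * 2, 128); // 8064
    constexpr std::size_t count_bytes = round_up(gemm_welford_size * 4, 128); // 16000

    // Layout within the single workspace allocation:
    //   mean  at offset 0
    //   var   at offset mean_bytes
    //   count at offset mean_bytes + var_bytes
    //   e     at offset mean_bytes + var_bytes + count_bytes
    //         (only when EMeanVarDataType != HDataType; otherwise e aliases p_h_grid_)
    static_assert(mean_bytes == 8064 && var_bytes == 8064 && count_bytes == 16000);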

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
        {
            return false;
        }

        // No need to check for splitK because we force KBatch = 1 (no support)

        if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
                     std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
        {
            if(ck::is_gfx11_supported())
            {
                return false;
            }
        }

        if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        typename GridwiseGemmWelford::Argument gemm_arg{
            std::array<const void*, 1>{arg.p_a_grid_},
            std::array<const void*, 1>{arg.p_b_grid_},
            arg.p_ds_grid_,
            static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
            arg.MRaw_,
            arg.NRaw_,
            arg.KRaw_,
            std::array<index_t, 1>{arg.StrideA_}, // StrideAs
            std::array<index_t, 1>{arg.StrideB_}, // StrideBs
            arg.StrideDs_,                        // StrideDs
            arg.StrideH_,                         // StrideE
            I1,                                   // kbatch
            arg.a_element_op_,
            arg.b_element_op_,
            arg.cde_element_op_};

        const auto a_grid_desc_ak0_m_ak1 =
            GridwiseGemmWelford::MakeAsGridDescriptor_AK0_M_AK1(gemm_arg.M,
                                                                gemm_arg.MPadded,
                                                                gemm_arg.K,
                                                                gemm_arg.KPadded,
                                                                gemm_arg.StrideAs,
                                                                gemm_arg.AK0);
        const auto b_grid_desc_bk0_n_bk1 =
            GridwiseGemmWelford::MakeBsGridDescriptor_BK0_N_BK1(gemm_arg.K,
                                                                gemm_arg.KPadded,
                                                                gemm_arg.N,
                                                                gemm_arg.NPadded,
                                                                gemm_arg.StrideBs,
                                                                gemm_arg.BK0);

        const auto M = a_grid_desc_ak0_m_ak1[I0].GetLength(I1);
        const auto N = b_grid_desc_bk0_n_bk1[I0].GetLength(I1);
        const auto K =
            a_grid_desc_ak0_m_ak1[I0].GetLength(I0) * a_grid_desc_ak0_m_ak1[I0].GetLength(I2);

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
        {
            return false;
        }

        return GridwiseGemmWelford::CheckValidity(gemm_arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             const void* p_gamma,
                             const void* p_beta,
                             void* p_h,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideH,
                             double epsilon,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op,
                             HElementwiseOperation h_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_gamma,
                        p_beta,
                        p_h,
                        MRaw,
                        NRaw,
                        KRaw,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideH,
                        epsilon,
                        a_element_op,
                        b_element_op,
                        cde_element_op,
                        h_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      const void* p_gamma,
                                                      const void* p_beta,
                                                      void* p_h,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      std::array<index_t, NumDTensor> StrideDs,
                                                      index_t StrideH,
                                                      double epsilon,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CDEElementwiseOperation cde_element_op,
                                                      HElementwiseOperation h_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_gamma,
                                          p_beta,
                                          p_h,
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideH,
                                          epsilon,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op,
                                          h_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3"
            << ">"
            << "BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << "WaveTile: "
            << MPerWmma << "x" << NPerWmma << ", "
            << "WaveMap: "
            << MRepeat << "x" << NRepeat << ", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "GemmSpec: "
            << getGemmSpecializationString(GemmSpec) << ", "
            << "VmemWriteThreadCluster: "
            << CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1) << ", "
            << CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3) << ", "
            << "LayerNormThreadCluster: "
            << LayernormThreadClusterSize_M_N::At(I0) << ", "
            << LayernormThreadClusterSize_M_N::At(I1) << ", "
            << "LayerNormThreadSliceSize: "
            << LayernormThreadSliceSize_M << ", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemmWelford::BlockwiseGemmPipe::PrefetchStages << ", "
            << "KPack: "
            << GridwiseGemmWelford::KPack;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -60,8 +60,8 @@ struct AddReluAdd
     }
 
     template <>
-    __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
-        half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
+    __host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
+        float& y, const float& x0, const half_t& x1, const half_t& x2) const
     {
         float a = x0 + x1;
         float b = a > 0 ? a : 0;
@@ -69,6 +69,15 @@ struct AddReluAdd
         y = c;
     }
 
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
+        half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
+    {
+        float y_float;
+        (*this)(y_float, x0, x1, x2);
+        y = y_float;
+    }
+
     template <>
     __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
         bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const

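The reshuffled specializations above make the float-output overload the single place where the arithmetic lives; the half-output overload now computes in float and narrows once. A self-contained mini-version of the same pattern (_Float16 is a GCC/Clang extension standing in for ck::half_t; all names here are illustrative):

    struct AddReluAddSketch
    {
        // float output: holds the actual add -> relu -> add arithmetic
        void operator()(float& y, float x0, _Float16 x1, _Float16 x2) const
        {
            const float a = x0 + static_cast<float>(x1);
            const float b = a > 0 ? a : 0;
            y             = b + static_cast<float>(x2);
        }
        // half output: delegate to the float overload, then narrow once
        void operator()(_Float16& y, float x0, _Float16 x1, _Float16 x2) const
        {
            float y_float;
            (*this)(y_float, x0, x1, x2);
            y = static_cast<_Float16>(y_float);
        }
    };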
@@ -0,0 +1,510 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"

namespace ck {

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe,
          index_t BlockSize>
struct EpilogueWelfordCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe>;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
    using Base::GetCShuffleLDSDescriptor;
    using Base::GetVgprToLDSEpilogueDescriptor;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::NumDTensor;

    template <typename DoPads, index_t MPerTile, index_t NPerTile>
    __host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
    {
        const auto grid_desc_m_n =
            make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
        return tensor_operation::device::PadTensorDescriptor(
            grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
    }

    template <typename DoPads, index_t MPerTile, index_t NPerTile>
    __host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N)
    {
        // We will broadcast [N] to [M, N] in this descriptor
        // Hence, 1st stride is 0
        const auto grid_desc_m_n =
            make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
        return tensor_operation::device::PadTensorDescriptor(
            grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
    }
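    // [Illustrative note, not part of the original commit] With strides (0, 1),
    // element (m, n) maps to offset 0 * m + 1 * n = n, so every one of the M
    // rows aliases the same N counts: a 1-D count buffer of length N is viewed
    // as an [M, N] tensor without physically replicating it M times.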

    template <typename GridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
    {
        const auto M      = grid_desc_m_n.GetLength(I0);
        const auto NBlock = grid_desc_m_n.GetLength(I1);
        const auto MBlock = M / MPerBlock;

        const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
            grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_pass_through_transform(NBlock)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        return grid_desc_mblock_mperblock_nblock;
    }

    using GemmMeanVarGridDesc_M_N =
        decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));

    using GemmCountGridDesc_M_N =
        decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));

    __device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_,
                                       EDataType* p_welford_var_grid_,
                                       int32_t* p_welford_count_grid_,
                                       index_t MRaw_,
                                       index_t NRaw_)
        : p_welford_mean_grid(p_welford_mean_grid_),
          p_welford_var_grid(p_welford_var_grid_),
          p_welford_count_grid(p_welford_count_grid_),
          NRaw(NRaw_)
    {
        index_t gemm_nblock = math::integer_divide_ceil(NRaw_, NPerBlock);

        gemm_mean_var_grid_desc_m_nblock =
            MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);

        gemm_count_grid_desc_m_nblock =
            MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
    }

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ void Run(CThreadBuf& c_thread_buf,
                        DsGridPointer p_ds_grid,
                        EDataType* p_e_grid,
                        void* p_shared,
                        const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
                        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            e_grid_desc_mblock_mperblock_nblock_nperblock,
                        CDEElementwiseOperation& cde_element_op,
                        const index_t& block_m_id,
                        const index_t& block_n_id)
    {
        // Vmem buffers
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        auto mean_var_grid_desc_mblock_mperblock_nblock =
            MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
                gemm_mean_var_grid_desc_m_nblock);

        auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        auto count_grid_desc_mblock_mperblock_nblock =
            MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(gemm_count_grid_desc_m_nblock);
        auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());

        // tuple of reference to C/Ds tensor buffers (mix LDS and Vmem)
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer Vgpr to LDS
        auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();

        // Space Filling Curve Vgpr
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // Space Filling Curve Vmem
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // C thread descriptor
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer LDS to Vmem
        auto cde_shuffle_block_copy_lds_and_global =
            Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, AccDataType>(
                c_ds_desc_refs,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                cde_element_op,
                block_m_id,
                block_n_id);

        // Block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                GetCShuffleLDSDescriptor();

        // E Vgpr buffer
        constexpr index_t PostShuffleThreadSliceSize_M =
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1);

        constexpr index_t PostShuffleThreadSliceSize_N =
            (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3);

        constexpr auto PostShuffleThreadSliceSize_M_N =
            Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};

        // Welford
        constexpr auto post_shuffle_thread_desc_m_n =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{},
                                                           Number<PostShuffleThreadSliceSize_M>{},
                                                           Number<1>{},
                                                           Number<PostShuffleThreadSliceSize_N>{}));

        auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            post_shuffle_thread_desc_m_n.GetElementSpaceSize());

        using PostShuffleThreadClusterSize_M_N = Sequence<
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>;

        constexpr auto post_shuffle_thread_cluster_desc =
            make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});

        const auto post_shuffle_thread_cluster_idx =
            post_shuffle_thread_cluster_desc.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

        const auto post_shuffle_thread_data_idx_begin =
            post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;

        constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(
            Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));

        constexpr auto thread_welford_dst_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<PostShuffleThreadSliceSize_M>{}));

        using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
                                                    decltype(thread_welford_src_desc_m_k),
                                                    decltype(thread_welford_dst_desc_m)>;

        using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                                  BlockSize,
                                                  PostShuffleThreadClusterSize_M_N,
                                                  Sequence<0, 1>,
                                                  false>;

        constexpr int num_shuffleM =
            MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);

        constexpr int num_shuffleN =
            NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);

        using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            thread_welford_dst_desc_m.GetElementSpaceSize()));

        using welford_count_vgpr_type =
            decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                thread_welford_dst_desc_m.GetElementSpaceSize()));

        Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
        Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
        Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
        Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;

        int max_count     = PostShuffleThreadSliceSize_N * num_shuffleN;
        const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);

        // tail block
        if(block_n_id % nblock == nblock - 1)
        {
            constexpr index_t NPerShuffleBlock =
                CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;

            int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
            int thread_max_len =
                PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
            int shuffle_step = 0;
            while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
            {
                ++shuffle_step;
                thread_max_len += NPerShuffleBlock;
            }

            int delta = 0;
            if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
                delta = 0;
            else if(NPerBlockTail > thread_max_len)
                delta = PostShuffleThreadSliceSize_N;
            else
                delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;

            max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
        }
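        // [Illustrative note, not part of the original commit] max_count is the
        // number of valid (in-bounds) N elements this thread feeds to Welford.
        // Example: PostShuffleThreadSliceSize_N = 8, NPerShuffleBlock = 64,
        // num_shuffleN = 2, NPerBlockTail = 70. The thread covering columns
        // [0, 8) and [64, 72) leaves the loop with shuffle_step = 1 and
        // thread_max_len = 72, takes the last branch (delta = 8 - 72 + 70 = 6),
        // so max_count = 8 + 6 = 14: all 8 columns of the first pass plus the
        // 6 in-bounds columns [64, 70) of the second pass.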

        // Initialize Welford
        static_for<0, num_shuffleM, 1>{}([&](auto i) {
            threadwise_welfords(i).max_count_ = max_count;
            mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
                var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
                welford_count_thread_bufs(i)(j) = 0;
            });
        });

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

        // Run CShuffle + Store E + Welford threadwise
        int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread shuffles its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // Read LDS / Vmem + CDE elementwise operation
            cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);

            // Store to Vmem, but keep data in Vgpr for Welford
            cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf),
                tie(post_shuffle_thread_desc_m_n),
                tie(e_thread_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_global_step);
                });

                // move on E
                cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
            }

            // Threadwise welford
            auto& threadwise_welford = threadwise_welfords(shuffleM_index);
            auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
            auto& var_thread_buf = var_thread_bufs(shuffleM_index);

            threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
                constexpr int shuffleMInc =
                    de_global_step[I1] /
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I1);
                shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
            }
        });
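
        // ThreadwiseWelford::Run above is assumed to apply the standard Welford
        // update once per valid sample x in each output row:
        //   count <- count + 1
        //   delta  = x - mean
        //   mean  <- mean + delta / count
        //   var   <- var + delta * (x - mean)   // accumulates M2, not yet variance
        // max_count_ stops the update on padded columns of the tail block so
        // they do not bias the statistics.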

        // Blockwise welford and write out
        static_for<0, num_shuffleM, 1>{}([&](auto i) {
            auto& mean_thread_buf = mean_thread_bufs(i);
            auto& var_thread_buf = var_thread_bufs(i);
            auto& count_thread_buf = welford_count_thread_bufs(i);

            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                block_sync_lds();
                count_thread_buf(j) = threadwise_welfords(i).cur_count_;
                BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
            });

            if(post_shuffle_thread_cluster_idx[I1] == 0)
            {
                constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
                    make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));

                constexpr int shuffleMPerBlock =
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I1);

                auto mean_var_count_thread_copy_index = make_multi_index(
                    block_m_id,                                                    // mblock
                    shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
                    block_n_id);                                                   // nblock

                auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                    AccDataType,
                    EDataType,
                    decltype(thread_welford_desc_I_m_I),
                    decltype(mean_var_grid_desc_mblock_mperblock_nblock),
                    tensor_operation::element_wise::PassThrough,
                    Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                    Sequence<0, 1, 2>,
                    1,
                    1,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{mean_var_grid_desc_mblock_mperblock_nblock,
                          mean_var_count_thread_copy_index,
                          tensor_operation::element_wise::PassThrough{}};

                mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                        make_tuple(I0, I0, I0),
                                                        mean_thread_buf,
                                                        mean_var_grid_desc_mblock_mperblock_nblock,
                                                        mean_grid_buf); // write mean

                mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                        make_tuple(I0, I0, I0),
                                                        var_thread_buf,
                                                        mean_var_grid_desc_mblock_mperblock_nblock,
                                                        var_grid_buf); // write variance

                // Stride of count is [0, 1]. Only the first row, count[0, 0:nblock],
                // needs to be written.
                if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
                {
                    auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                        int32_t,
                        int32_t,
                        decltype(thread_welford_desc_I_m_I),
                        decltype(count_grid_desc_mblock_mperblock_nblock),
                        tensor_operation::element_wise::PassThrough,
                        Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                        Sequence<0, 1, 2>,
                        1,
                        1,
                        InMemoryDataOperationEnum::Set,
                        1,
                        false>{count_grid_desc_mblock_mperblock_nblock,
                               mean_var_count_thread_copy_index,
                               tensor_operation::element_wise::PassThrough{}};

                    count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                         make_tuple(I0, I0, I0),
                                                         count_thread_buf,
                                                         count_grid_desc_mblock_mperblock_nblock,
                                                         welford_count_grid_buf); // write count
                }
            }
        });
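
        // BlockwiseWelford::Run above is assumed to merge the per-thread
        // partial results across the N dimension of the thread cluster with
        // the parallel combination formula (Chan et al.):
        //   count_ab = count_a + count_b
        //   delta    = mean_b - mean_a
        //   mean_ab  = mean_a + delta * count_b / count_ab
        //   m2_ab    = m2_a + m2_b + delta^2 * count_a * count_b / count_ab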
    }

    EDataType* p_welford_mean_grid;
    EDataType* p_welford_var_grid;
    int32_t* p_welford_count_grid;
    index_t NRaw;
    GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock;
    GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock;
};

} // namespace ck
@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"

namespace ck {

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe>
struct EpilogueCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe>;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
    using Base::GetCShuffleLDSDescriptor;
    using Base::GetVgprToLDSEpilogueDescriptor;
    using Base::I1;
    using Base::NumDTensor;

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ static void Run(CThreadBuf& c_thread_buf,
                               DsGridPointer p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
    {
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping in single thread.
        constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
            BlockwiseGemmPipe::
                GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());

        // Thread transfer Vgpr to LDS
        auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();

        // Space Filling Curve Vgpr
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // Space Filling Curve Vmem
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // Block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                GetCShuffleLDSDescriptor();

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer LDS to Vmem
        auto cde_shuffle_block_copy_lds_and_global =
            Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
                c_ds_desc_refs,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                cde_element_op,
                block_m_id,
                block_n_id);

        // tuple of reference to C/Ds tensor buffers
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

        // CShuffle and Store
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread writes its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // each block loads its C data from LDS, D from global, applies elementwise
            // operation and stores result E to global
            cde_shuffle_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_global_step);
                });

                // move on E
                cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
            }
        });
    }
};

} // namespace ck
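
// Usage sketch (illustration only, names as used further below): a gridwise
// GEMM that has accumulated its results in c_thread_buf finishes with
//
//   auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
//   epilogue_args.template Run<InMemoryDataOperationEnum::Set>(
//       c_thread_buf, p_ds_grid, p_e_grid, p_shared,
//       ds_grid_desc_mblock_mperblock_nblock_nperblock,
//       e_grid_desc_mblock_mperblock_nblock_nperblock,
//       cde_element_op, block_m_id, block_n_id);
//
// which matches the call made at the end of
// GridwiseGemm_wmma_cshuffle_v3_base::Run in this patch.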
@@ -0,0 +1,253 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"

namespace ck {

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe>
struct EpilogueCShuffleBase
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    static constexpr index_t NumDTensor = DsDataType::Size();
    static constexpr auto EShuffleBlockTransferScalarPerVector =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];

    using SpaceFillingCurveVgpr =
        SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs>,
                          Sequence<0, 1, 2, 3, 4, 5, 6>,
                          Sequence<CShuffleMRepeatPerShuffle,
                                   1,
                                   1,
                                   CShuffleNRepeatPerShuffle,
                                   1,
                                   1,
                                   BlockwiseGemmPipe::MAccVgprs>>;

    using SpaceFillingCurveVmem = SpaceFillingCurve<
        Sequence<1, MPerBlock, 1, NPerBlock>,
        Sequence<0, 2, 1, 3>,
        Sequence<1,
                 CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
                 1,
                 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;

    // *Caution: here "repeat" means shuffle repeat
    __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    __device__ static constexpr auto GetCShuffleLDSDescriptor()
    {
        // C mapping in single block
        constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
            BlockwiseGemmPipe::
                GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        constexpr auto MWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I1);
        constexpr auto MSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I2);
        constexpr auto NWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I4);
        constexpr auto NThreadPerSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I5);
        constexpr auto MAccVgprs =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I6);

        return transform_tensor_descriptor(
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(),
            make_tuple(make_freeze_transform(I0),
                       make_unmerge_transform(make_tuple(
                           Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                           MWave,                               // MWave
                           MSubGroup,                           // MSubGroup * MAccVgprs = MPerWmma
                           MAccVgprs)),
                       make_freeze_transform(I0),
                       make_unmerge_transform(make_tuple(
                           Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                           NWave,                               // NWave
                           NThreadPerSubGroup))),               // NThreadPerSubGroup = NPerWmma
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
    }
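
    // Shape sketch for the transform above, under assumed WMMA parameters
    // (illustrative only): with MPerWmma = NPerWmma = 16, MWave = NWave = 2,
    // MSubGroup = 2, MAccVgprs = 8, NThreadPerSubGroup = 16 and
    // CShuffleMRepeatPerShuffle = CShuffleNRepeatPerShuffle = 1, the flat
    // (1, 32, 1, 32) shuffle block becomes a (1, 2, 2, 1, 2, 16, 8) view:
    // M unmerges into MRepeat x MWave x MSubGroup x MAccVgprs (1*2*2*8 = 32)
    // and N into NRepeat x NWave x NThreadPerSubGroup (1*2*16 = 32).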

    __device__ static auto GetVgprToLDSEpilogueDescriptor()
    {
        // C mapping in single block
        constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
            BlockwiseGemmPipe::
                GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        constexpr auto MWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I1);
        constexpr auto MSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I2);
        constexpr auto NWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I4);
        constexpr auto NThreadPerSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I5);
        constexpr auto MAccVgprs =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I6);

        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
                .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        return ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType,
            CShuffleDataType,
            decltype(BlockwiseGemmPipe::
                         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()),
            decltype(GetCShuffleLDSDescriptor()),
            ck::tensor_operation::element_wise::PassThrough,
            Sequence<CShuffleMRepeatPerShuffle,
                     I1,
                     I1,
                     CShuffleNRepeatPerShuffle,
                     I1,
                     I1,
                     MAccVgprs>,
            Sequence<0, 1, 2, 3, 4, 5, 6>,
            6,
            1,
            InMemoryDataOperationEnum::Set,
            1,
            true>{GetCShuffleLDSDescriptor(),
                  make_multi_index(0,
                                   m_thread_data_on_block_idx[I1],
                                   m_thread_data_on_block_idx[I2],
                                   0,
                                   n_thread_data_on_block_idx[I1],
                                   n_thread_data_on_block_idx[I2],
                                   m_thread_data_on_block_idx[I3]),
                  ck::tensor_operation::element_wise::PassThrough{}};
    }

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename InterDataType,
              typename CDsDescRefs,
              typename EGridDesc>
    __device__ static auto
    GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
                                   EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
                                   CDEElementwiseOperation& cde_element_op,
                                   const index_t& block_m_id,
                                   const index_t& block_n_id)
    {
        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c_ds_block_begin = container_concat(
            make_tuple(make_multi_index(0, 0, 0, 0)),
            generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
                           Number<NumDTensor>{}));

        // blockwise copy which loads C from LDS, D from global, applies elementwise
        // operation and stores result E to global
        return ThreadGroupTensorSliceTransfer_v7r3<
            ThisThreadBlock, // ThreadGroup
            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
            Tuple<EDataType>,
            CDsDescRefs,
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CDEElementwiseOperation,                                    // ElementwiseOperation,
            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
            Sequence<1,
                     CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
                     1,
                     CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
                         NPerWmma>, // BlockSliceLengths,
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>,                    // ThreadClusterArrangeOrder,
            Sequence<0, 1, 2, 3>,                    // SrcDimAccessOrder,
            Sequence<0, 1, 2, 3>,                    // DstDimAccessOrder,
            3,                                       // SrcVectorDim,
            3,                                       // DstVectorDim,
            CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
            EShuffleBlockTransferScalarPerVector,    // DstScalarPerVector
            sequence_merge_t<
                Sequence<true>,
                uniform_sequence_gen_t<NumDTensor,
                                       false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
            Sequence<false>,                    // ThreadTransferDstResetCoordinateAfterRunFlags
            1,
            Tuple<InterDataType>>{c_ds_desc_refs,
                                  idx_c_ds_block_begin,
                                  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                                  make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                                  cde_element_op};
    }
};

} // namespace ck
@@ -315,8 +315,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
    using Base::MakeDsGridDescriptor_M_N;
    using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;

    using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -556,7 +554,8 @@ struct GridwiseGemm_wmma_cshuffle_v3

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
@@ -565,7 +564,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op)
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -610,6 +610,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
            decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
            decltype(b_scale_struct),
            decltype(epilogue_args),
            HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum>(p_as_grid,
@@ -627,16 +628,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
                     block_m_id,
                     block_n_id,
                     num_k_block_per_scale,
                     b_scale_struct);
                     b_scale_struct,
                     epilogue_args);
    }

    // Wrapper function to have __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void
    Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        // shift A matrices pointer for splitk
        AsGridPointer p_as_grid_splitk;
@@ -663,7 +668,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
                     karg,
                     karg.a_element_op,
                     karg.b_element_op,
                     karg.cde_element_op);
                     karg.cde_element_op,
                     epilogue_args);
    }
};

@@ -209,8 +209,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
    using Base::MakeDsGridDescriptor_M_N;
    using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;

    using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -533,7 +531,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
@@ -543,7 +542,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op)
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -593,6 +593,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
            decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
            decltype(b_scale_struct),
            decltype(epilogue_args),
            HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum>(p_as_grid,
@@ -610,16 +611,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
                     block_m_id,
                     block_n_id,
                     num_k_block_per_scale,
                     b_scale_struct);
                     b_scale_struct,
                     epilogue_args);
    }

    // NOTE: Wrapper function to have __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void
    Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        // shift A matrices pointer for splitk
        AsGridPointer p_as_grid_splitk;
@@ -647,7 +652,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
                     karg,
                     karg.a_element_op,
                     karg.b_element_op,
                     karg.cde_element_op);
                     karg.cde_element_op,
                     epilogue_args);
    }
};

@@ -23,6 +23,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"

namespace ck {

@@ -46,12 +48,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                                             std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
        typename GridwiseGemm::EpilogueCShuffle>();
    __shared__ char p_shared[LDS_size];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};

    GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
        p_shared, splitk_batch_offset, karg);
        p_shared, splitk_batch_offset, karg, epilogue_args);

#if defined(__gfx11__)
}
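
// A Welford-fused (layernorm) kernel is expected to differ only in the
// epilogue it instantiates; hypothetical sketch, mirroring the code above:
//
//   constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
//       typename GridwiseGemm::EpilogueWelfordCShuffle>();
//   __shared__ char p_shared[LDS_size];
//   auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle{
//       /* mean/var/count pointers, NRaw and grid descriptors */};
//   GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
//       p_shared, splitk_batch_offset, karg, epilogue_args);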
@@ -262,9 +268,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
    static_assert(!PermuteA, "PermuteA is not supported");

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
    {
@@ -539,23 +543,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                          Number<NumDTensor>{});
    }

    __host__ __device__ static constexpr auto
    // *Caution Here repeat is shuffle repeat
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
                                                           BlkGemmPipeSched,
@@ -578,6 +565,46 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                                                           NRepeat,
                                                           KPack>())>;

    // Used to create an epilogue object in the __global__ function and pass it to the Run method
    using EpilogueCShuffle =
        EpilogueCShuffle<DsDataType,
                         EDataType,
                         AccDataType,
                         CShuffleDataType,
                         MPerBlock,
                         NPerBlock,
                         MPerWmma,
                         NPerWmma,
                         MRepeat,
                         NRepeat,
                         CShuffleMRepeatPerShuffle,
                         CShuffleNRepeatPerShuffle,
                         CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                         CDEShuffleBlockTransferScalarPerVectors,
                         CDEElementwiseOperation,
                         ThisThreadBlock,
                         BlockwiseGemmPipe>;

    using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe,
        BlockSize>;

    template <typename DEGridDesc>
    __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
@@ -821,6 +848,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }

    template <typename EpilogueType>
    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
@@ -838,7 +866,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
            EpilogueType::
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
@@ -867,6 +896,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename BScaleStruct,
              typename EpilogueArgument,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
@@ -887,7 +917,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                               const index_t& block_m_id,
                               const index_t& block_n_id,
                               const index_t& num_k_block_per_scale,
                               BScaleStruct& b_scale_struct)
                               BScaleStruct& b_scale_struct,
                               EpilogueArgument& epilogue_args)
    {
        const auto as_grid_buf = generate_tuple(
            [&](auto i) {
@@ -903,16 +934,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            },
            Number<NumBTensor>{});

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

@@ -984,240 +1005,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                                               num_k_block_per_scale);

        // shuffle C and write out
        {
            // C mapping in single thread.
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm_pipeline
                    .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // C mapping in single block
            constexpr auto
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                    blockwise_gemm_pipeline
                        .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I1);
            constexpr auto MSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I2);
            constexpr auto NWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I4);
            constexpr auto NThreadPerSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I5);
            constexpr auto MAccVgprs =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I6);

            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                    .GetElementSpaceSize());

            constexpr auto
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                    transform_tensor_descriptor(
                        c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                        make_tuple(
                            make_freeze_transform(I0),
                            make_unmerge_transform(make_tuple(
                                Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                                MWave,                               // MWave
                                MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
                                MAccVgprs)),
                            make_freeze_transform(I0),
                            make_unmerge_transform(make_tuple(
                                Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                                NWave,                               // NWave
                                NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                        make_tuple(Sequence<>{},
                                   Sequence<0, 1, 2, 6>{},
                                   Sequence<>{},
                                   Sequence<3, 4, 5>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
                make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
                                                     MRepeat, MWave, MSubGroup, MAccVgprs))),
                                                 make_tuple(Sequence<0, 1, 2, 3>{}),
                                                 make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
                    .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
                                                     NRepeat, NWave, NThreadPerSubGroup))),
                                                 make_tuple(Sequence<0, 1, 2>{}),
                                                 make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
                    .CalculateBottomIndex(make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMRepeatPerShuffle,
                         I1,
                         I1,
                         CShuffleNRepeatPerShuffle,
                         I1,
                         I1,
                         MAccVgprs>,
                Sequence<0, 1, 2, 3, 4, 5, 6>,
                6,
                1, // vector write pixel
                InMemoryDataOperationEnum::Set,
                1,
                true>{
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                make_multi_index(0,
                                 m_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 0,
                                 n_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3]),
                ck::tensor_operation::element_wise::PassThrough{}};

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                generate_tie([&](auto i) -> const auto& // return type should be reference
                             { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                             Number<NumDTensor>{}));

            // tuple of reference to C/Ds tensor buffers
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie([&](auto i) -> const auto& // return type should be reference
                             { return ds_grid_buf[i]; },
                             Number<NumDTensor>{}));

            // tuple of starting index of C/Ds blockwise copy
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
                               Number<NumDTensor>{}));

            // blockwise copy which loads C from LDS, D from global, applies elementwise
            // operation and stores result E to global
            auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
                ThisThreadBlock, // ThreadGroup
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
                Tuple<EDataType>,
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CDEElementwiseOperation,                                    // ElementwiseOperation,
                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>,                    // ThreadClusterArrangeOrder,
                Sequence<0, 1, 2, 3>,                    // SrcDimAccessOrder,
                Sequence<0, 1, 2, 3>,                    // DstDimAccessOrder,
                3,                                       // SrcVectorDim,
                3,                                       // DstVectorDim,
                CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
                EShuffleBlockTransferScalarPerVector,    // DstScalarPerVector
                sequence_merge_t<
                    Sequence<true>,
                    uniform_sequence_gen_t<NumDTensor,
                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
                Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                 cde_element_op};

            // space filling curve for local reg & global memory
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6>,
                                  Sequence<CShuffleMRepeatPerShuffle,
                                           1,
                                           1,
                                           CShuffleNRepeatPerShuffle,
                                           1,
                                           1,
                                           MAccVgprs>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_cde_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                                           1,
                                           CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                    c_thread_buf,
                    c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block loads its C data from LDS, D from global, applies elementwise
                // operation and stores result E to global
                cde_shuffle_block_copy_lds_and_global.Run(
                    c_ds_desc_refs,
                    c_ds_buf_refs,
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    tie(e_grid_buf));

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                    // move on Ds
                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_global_step);
                    });

                    // move on E
                    cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
                }
            });
        }
        epilogue_args.template Run<EGlobalMemoryDataOperation>(
            c_thread_buf,
            p_ds_grid,
            p_e_grid,
            p_shared,
            ds_grid_desc_mblock_mperblock_nblock_nperblock,
            e_grid_desc_mblock_mperblock_nblock_nperblock,
            cde_element_op,
            block_m_id,
            block_n_id);
    }
};

@@ -43,7 +43,8 @@ template <typename SrcDatas,
          index_t DstScalarPerVector,
          typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
          typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
          index_t NumThreadScratch = 1>
          index_t NumThreadScratch = 1,
          typename InterDatas = DstDatas>
struct ThreadwiseTensorSliceTransfer_v7r3
{
    static constexpr auto I0 = Number<0>{};
@@ -153,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
        // loop over space-filling curve
        static_for<0, src_num_access, 1>{}([&](auto iAccess) {
            auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
            auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
            auto elm_vectors = generate_vectors<InterDatas, SrcScalarPerVector>();

            bool oob_val = true;

@@ -226,9 +227,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3
            auto dst_data_refs = generate_tie(
                // return type should be lvalue
                [&](auto iDst) -> auto& {
                    using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
                    using InterData = remove_cvref_t<tuple_element_t<iDst.value, InterDatas>>;

                    using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
                    using elem_op_vec_t =
                        typename vector_type<InterData, elem_op_vec_len>::type;

                    return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
                },
@@ -297,17 +299,17 @@ struct ThreadwiseTensorSliceTransfer_v7r3
    __device__ void
    TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
        using InterData = remove_cvref_t<decltype(InterDatas{}[I0])>;

        using ElmThreadScratch =
            StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                            DstData,
                                            InterData,
                                            SrcScalarPerVector,
                                            decltype(GetSrcThreadScratchDescriptor()),
                                            true>;
        using DstThreadScratch =
            StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                            DstData,
                                            InterData,
                                            DstScalarPerVector,
                                            decltype(GetDstThreadScratchDescriptor()),
                                            true>;
@@ -319,11 +321,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3
            bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);

        if constexpr(SrcVectorDim != DstVectorDim &&
                     ((is_same<half_t, remove_cvref_t<DstData>>::value &&
                     ((is_same<half_t, remove_cvref_t<InterData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                      (is_same<f8_t, remove_cvref_t<InterData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
                      (is_same<int8_t, remove_cvref_t<DstData>>::value &&
                      (is_same<int8_t, remove_cvref_t<InterData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
        {
            // each transpose does
@@ -356,8 +358,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
                constexpr auto data_idx_seq = generate_sequence_v2(
                    [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});

                using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
                using src_vector_t = vector_type_maker_t<InterData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<InterData, DstScalarPerVector>;

                // get DstScalarPerVector # of read-only references to src vectors from
                // src_thread_scratch_
@@ -380,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
                    Number<num_dst_vector>{});

                // do data transpose
                transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
                transpose_vectors<InterData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
            });
        }
@@ -393,6 +395,104 @@ struct ThreadwiseTensorSliceTransfer_v7r3
        dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
    }

    // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
    // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
    // DstVgprDescs: Tuple<const DstVgprDesc0&, const DstVgprDesc1&, ...>
    // DstVgprBuffers: Tuple<DstVgprBuffer0&, DstVgprBuffer1&, ...>
    template <typename DstBuffers,
              typename DstVgprDescs,
              typename DstVgprBuffers,
              index_t ThreadScratchId = 0,
              enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
    __device__ void
    RunWriteAndStoreVgpr(const DstDescs& dst_descs,
                         DstBuffers dst_bufs,
                         const DstVgprDescs&,
                         DstVgprBuffers dst_vgpr_buf,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        // Same functionality as RunWrite, but additionally keeps the internal Vgpr data in dst_vgpr_buf
        OOBCheck(thread_scratch_id);
        TransposeFromElmToDst(thread_scratch_id);

        // Vgpr buffer origin is set internally to 0
        constexpr auto dst_slice_origin_idx =
            generate_tuple([&](auto) { return I0; }, Number<nDim>{});
        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        // loop over space-filling curve
        static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
            auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];

            static_for<0, nDst, 1>{}([&](auto i) {
                // copy data from dst_vectors into dst_bufs
                using DstData = remove_cvref_t<decltype(DstDatas{}[i])>;
                using InterData = remove_cvref_t<decltype(InterDatas{}[i])>;

                typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
                using dst_vector_t =
                    typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

                static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
                    dst_vector.template AsType<DstData>()(j) =
                        type_convert<DstData>(dst_vectors[i].template AsType<InterData>()[j]);
                });

                const bool is_dst_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
                                                                                dst_coords_[i]);

                constexpr InMemoryDataOperationEnum DstInMemOp =
                    static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));

                dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
                    dst_coords_[i].GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[I0]);

                // store Vgpr
                using DstVgprDesc = remove_cvref_t<decltype(DstVgprDescs{}.At(i))>;
                static_assert(DstVgprDesc::IsKnownAtCompileTime(),
                              "wrong! DstVgprDesc needs to be known at compile-time");
                constexpr auto dst_vgpr_desc = DstVgprDesc{};

                constexpr auto src_data_idx = DstSpaceFillingCurve::GetIndex(iAccess);
                static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
                    constexpr index_t dst_offset =
                        dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
                                                      src_data_idx + j * dst_scalar_step_in_vector);

                    dst_vgpr_buf(I0)(Number<dst_offset>{}) =
                        is_dst_valid ? dst_vectors[i].template AsType<InterData>()[j]
                                     : NumericLimits<InterData>::QuietNaN();
                });
            });

            // move coordinate
            if constexpr(iAccess.value != dst_num_access - 1)
            {
                constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);

                static_for<0, nDst, 1>{}([&](auto i) {
                    move_tensor_coordinate(dst_descs[i],
                                           dst_coords_(i),
                                           make_tensor_coordinate_step(dst_descs[i], forward_step));
                });
            }
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            if constexpr(DstResetCoordinateAfterRunFlags::At(i))
            {
                const auto dst_reset_step =
                    make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());

                move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
            }
        });
    }
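
    // Instantiation sketch (assumed types, illustration only): keeping float
    // intermediates for a fused reduction while storing half to global would
    // use, e.g.,
    //
    //   ThreadwiseTensorSliceTransfer_v7r3<
    //       Tuple<float>,        // SrcDatas
    //       Tuple<half_t>,       // DstDatas: what RunWrite/RunWriteAndStoreVgpr store
    //       ...,                 // descriptors, element op, access config
    //       1,                   // NumThreadScratch
    //       Tuple<float>>        // InterDatas: what RunWriteAndStoreVgpr keeps in Vgpr
    //
    // Only RunWriteAndStoreVgpr converts InterDatas to DstDatas on the fly;
    // plain RunWrite still requires InterDatas == DstDatas (see the
    // static_assert added below).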

    // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
    // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
    template <typename DstBuffers,
@@ -402,6 +502,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
                 DstBuffers dst_bufs,
                 Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        static_assert(is_same_v<InterDatas, DstDatas>,
                      "RunWrite doesn't support inter data type different from dst data type");

        OOBCheck(thread_scratch_id);
        TransposeFromElmToDst(thread_scratch_id);

@@ -630,8 +733,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
|
||||
private:
|
||||
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<InterDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<InterDatas, DstScalarPerVector>());
|
||||
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
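For context, the write-out above does two jobs per destination buffer: it narrows the intermediate-typed vector (InterData, e.g. F32) to the destination type (DstData, e.g. F16) before the global store, and it keeps the intermediate-typed values resident in VGPRs for the follow-up Welford/layernorm pass, poisoning out-of-bounds lanes with a quiet NaN. A minimal stand-alone sketch of that per-element step, assuming plain scalar types and arrays in place of CK's vector_type machinery (InterData/DstData here are stand-ins, not the real tuple types):

    #include <limits>

    // Stand-in scalar types; in the kernel these come from the InterDatas/DstDatas tuples.
    using InterData = float;
    using DstData   = _Float16; // assumption: compiler extension type standing in for ck::half_t

    template <int ScalarPerVector>
    void write_out_and_keep(const InterData (&inter)[ScalarPerVector],
                            DstData (&dst)[ScalarPerVector],
                            InterData (&kept)[ScalarPerVector],
                            bool is_dst_valid)
    {
        for(int j = 0; j < ScalarPerVector; ++j)
        {
            // narrow to the destination type for the global write
            dst[j] = static_cast<DstData>(inter[j]); // type_convert<DstData>(...) in CK
            // keep the intermediate value for the later Welford pass; invalid lanes get NaN
            kept[j] = is_dst_valid ? inter[j]
                                   : std::numeric_limits<InterData>::quiet_NaN();
        }
    }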
0
include/ck_tile/host/tensor_shuffle_utils.hpp
Executable file → Normal file
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -15,6 +15,7 @@ namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_USE_XDL)
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
                                                             Row,
@@ -78,6 +79,73 @@ void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_ins
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>&);
#endif // CK_USE_XDL

#if defined(CK_USE_WMMA)
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
                                                             Row,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>&);

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
                                                             Col,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>&);

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
                                                             Row,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>&);

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
                                                             Col,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>&);
#endif

// GEMM + Add + Relu + Add + Layernorm
template <typename ALayout,
@@ -136,29 +204,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                 is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
                 is_same_v<HLayout, Row>)
    {
#if defined(CK_USE_XDL)
        add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
            op_ptrs);
#endif
#if defined(CK_USE_WMMA)
        add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
            op_ptrs);
#endif
    }
    else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                      is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
                      is_same_v<HLayout, Row>)
    {
#if defined(CK_USE_XDL)
        add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
            op_ptrs);
#endif
#if defined(CK_USE_WMMA)
        add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
            op_ptrs);
#endif
    }
    else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                      is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
                      is_same_v<HLayout, Row>)
    {
#if defined(CK_USE_XDL)
        add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
            op_ptrs);
#endif
#if defined(CK_USE_WMMA)
        add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
            op_ptrs);
#endif
    }
    else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                      is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
                      is_same_v<HLayout, Row>)
    {
#if defined(CK_USE_XDL)
        add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
            op_ptrs);
#endif
#if defined(CK_USE_WMMA)
        add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
            op_ptrs);
#endif
    }
}
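Taken together, the factory dispatch above lets a client request this fused operation by layouts and element types alone; whichever of the XDL and WMMA instance sets the build enables lands in the same op_ptrs vector. A sketch of the usual CK instance-factory flow for reference (the MakeArgumentPointer parameter list is elided here because it is specific to DeviceGemmMultipleDLayernorm; treat this as an outline, not a verbatim client):

    // Sketch only: DeviceOp's template arguments match the factory specialization above.
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm<
        Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F16, F16,
        PassThrough, PassThrough, AddReluAdd, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(/* pointers, lengths, strides and
                                                           element-wise ops: see the
                                                           operation's interface */);
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            auto invoker_ptr = op_ptr->MakeInvokerPointer();
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{});
            break; // run the first instance that supports the problem
        }
    }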
@@ -1,7 +1,12 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_relu_add_layernorm_instance
    device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp

    device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
    device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
)
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// output: h[m, n]
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

// irregular tile size
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances =
    std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
                                                             Row,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
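As a readability aid for this and the three layout variants that follow, the fused computation these instances implement is, in naive scalar host code (a reference sketch only, assuming row-major float buffers; the device instances use F16 I/O with F32 accumulation, and AddReluAdd computes relu(c + d0) + d1):

    #include <cmath>
    #include <vector>

    // Naive reference: e = relu((a x b) + d0) + d1; h = layernorm(e) * gamma + beta.
    void reference_gemm_add_relu_add_layernorm(
        const std::vector<float>& a,     // M x K
        const std::vector<float>& b,     // K x N
        const std::vector<float>& d0,    // M x N
        const std::vector<float>& d1,    // M x N
        const std::vector<float>& gamma, // N
        const std::vector<float>& beta,  // N
        std::vector<float>& h,           // M x N
        int M, int N, int K, float epsilon = 1e-4f)
    {
        std::vector<float> e(static_cast<size_t>(M) * N);
        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                const float x = acc + d0[m * N + n]; // first add
                const float r = x > 0.f ? x : 0.f;   // relu
                e[m * N + n]  = r + d1[m * N + n];   // second add
            }
            // layernorm over the N dimension of row m
            float mean = 0.f, var = 0.f;
            for(int n = 0; n < N; ++n)
                mean += e[m * N + n];
            mean /= N;
            for(int n = 0; n < N; ++n)
            {
                const float diff = e[m * N + n] - mean;
                var += diff * diff;
            }
            var /= N;
            const float inv_std = 1.f / std::sqrt(var + epsilon);
            for(int n = 0; n < N; ++n)
                h[m * N + n] = (e[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
        }
    }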
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// output: h[m, n]
// input: a[k, m], b[n, k], d0[m, n], d1[m, n], gamma[n], beta[n]
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

// irregular tile size
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances =
    std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
                                                             Col,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// output: h[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

// irregular tile size
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances =
    std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
                                                             Row,
                                                             Row_Row_Tuple,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16_F16_Tuple,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             AddReluAdd,
                                                             PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
    add_device_operation_instances(
        instances,
        device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances<
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// output: h[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n], gamma[n], beta[n]
template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;

template <BlockGemmPipelineScheduler GemmLoopScheduler, BlockGemmPipelineVersion GemmPipeline>
using device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances =
    std::tuple<
// clang-format off
//##########################################| A| B| Ds| H| AData| BData| DsData| HData| AccData| CShuffleData | EMeanVarData| GammaData| BetaData| A| B| CDE| H| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| Layernorm| Layernorm| LoopScheduler| Pipeline|
//##########################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Type | Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| ThreadClusterLengths| ThreadSliceSize| | |
//##########################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | _M_N| _M| | |
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 64, 32, 32, 32, 8, 8, 16, 16, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
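// Note on the irregular-tile list above: relative to the regular-tile instances it
// switches the GEMM specialization from GemmDefault to GemmMNKPadding, so problem
// sizes that are not multiples of the block tile are padded rather than rejected,
// and it narrows the A/B source and CDE shuffle ScalarPerVector values from 8 to 1,
// since padded edges cannot be accessed with full-width vector loads and stores.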
void add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});

add_device_operation_instances(
instances,
device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances<
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
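// A minimal consumption sketch (an illustration under stated assumptions, not code
// from this diff): the factory fills a vector of type-erased device ops which a
// client can filter and run. It assumes the usual CK device-op interface
// (MakeArgumentPointer, IsSupportedArgument, MakeInvokerPointer); the exact argument
// list of MakeArgumentPointer for this layernorm op is not shown here, so it is elided.
//
//   std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row, Col, Row_Row_Tuple, Row,
//       F16, F16, F16_F16_Tuple, F16, F16, F16,
//       PassThrough, PassThrough, AddReluAdd, PassThrough>>> instances;
//   add_device_gemm_add_relu_add_wmma_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(instances);
//   for(auto& op : instances)
//   {
//       auto arg = op->MakeArgumentPointer(/* buffers, lengths, strides, epsilon, ops... */);
//       if(op->IsSupportedArgument(arg.get()))
//           op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
//   }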
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -167,6 +167,12 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
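// h_m_n_host is the CPU-side buffer that receives the reference result when
// do_verification is enabled (the verification code itself lies outside this hunk).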

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "h_m_n: " << h_m_n.mDesc << std::endl;

switch(init_method)
{
case 0: break;
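// init_method 0 leaves the tensors uninitialized; the remaining cases (outside
// this hunk) fill them with generated values.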
@@ -312,9 +318,8 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
float gb_per_sec = num_byte / 1.E6 / ave_time;
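// Unit check: ave_time is in ms and num_byte in bytes, so num_byte / 1.E6 gives MB,
// and MB per ms equals GB per s, hence gb_per_sec is already in GB/s.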

if(time_kernel)
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;

if(ave_time < best_ave_time)
{

@@ -333,8 +338,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
}
else
{
if(time_kernel)
std::cout << op_name << " does not support this problem" << std::endl;
std::cout << op_name << " does not support this problem" << std::endl;
}
}

@@ -1,6 +1,8 @@
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp)
if(result EQUAL 0)
add_custom_target(test_gemm_layernorm)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
if(result EQUAL 0)
add_custom_target(test_gemm_layernorm)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
endif()
endif()

@@ -79,11 +79,6 @@ TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);
TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); }
int main(int argc, char** argv)
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
std::cout << "No available instance for gfx11 & gfx12." << std::endl;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}