WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and writes out
    - EpilogueWelfordCShuffle: holds Welford-specific arguments; runs CShuffle, writes out, runs the first Welford part, and writes out the Welford results

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Address review comments
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent e9596228ff
commit 4ebc48a3cd
23 changed files with 2678 additions and 332 deletions

View File

@@ -292,13 +292,15 @@ struct BlockwiseGemmWmmaops_pipeline_base
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
static constexpr auto MAccVgprs =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave

View File

@@ -42,7 +42,8 @@ template <typename ThreadGroup,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
index_t NumThreadScratch = 1,
typename InterDatas = DstDatas>
struct ThreadGroupTensorSliceTransfer_v7r3
{
static constexpr index_t nDim =
@@ -97,7 +98,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
@@ -123,7 +124,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
@@ -138,7 +139,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
@@ -148,6 +149,36 @@ struct ThreadGroupTensorSliceTransfer_v7r3
}
}
// Writes the slice to the destination buffer(s) and additionally stores the
// written data into VGPR buffer(s) so it can be reused for further fused
// operations, instead of being discarded after the write-out.
// The four constexpr branches only normalize the argument shapes: a non-tuple
// dst buffer and/or non-tuple vgpr buffer is wrapped with tie() so the
// underlying threadwise transfer always receives tuples.
// NOTE(review): the guard mirrors the other Run* methods — only threads whose
// id falls inside the thread cluster participate; extra threads are idle.
template <typename DstBuffers,
              typename DstVgprDescs,
              typename DstVgprBuffers,
              index_t ThreadScratchId = 0>
    __device__ void
    RunWriteAndStoreVgpr(const DstDescs& dst_descs,
                         DstBuffers dst_bufs,
                         const DstVgprDescs& dst_vgpr_desc,
                         DstVgprBuffers dst_vgpr_buf,
                         Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            // Both arguments already tuples: forward as-is.
            if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value &&
                         is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
                threadwise_transfer_.RunWriteAndStoreVgpr(
                    dst_descs, dst_bufs, dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
            // Only the vgpr buffer needs wrapping.
            else if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
                threadwise_transfer_.RunWriteAndStoreVgpr(
                    dst_descs, dst_bufs, dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
            // Only the dst buffer needs wrapping.
            else if constexpr(is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
                threadwise_transfer_.RunWriteAndStoreVgpr(
                    dst_descs, tie(dst_bufs), dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
            // Neither is a tuple: wrap both.
            else
                threadwise_transfer_.RunWriteAndStoreVgpr(
                    dst_descs, tie(dst_bufs), dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
        }
    }
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
@@ -162,7 +193,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
@@ -179,7 +210,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
@@ -212,7 +243,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
NumThreadScratch,
InterDatas>;
ThreadwiseTransfer threadwise_transfer_;
};