WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and write out
    - EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Address review comments
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent e9596228ff
commit 4ebc48a3cd
23 changed files with 2678 additions and 332 deletions

View File

@@ -0,0 +1,510 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
namespace ck {
template <typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
typename CDEElementwiseOperation,
typename ThisThreadBlock,
typename BlockwiseGemmPipe,
index_t BlockSize>
struct EpilogueWelfordCShuffle
: EpilogueCShuffleBase<DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>
{
using Base = EpilogueCShuffleBase<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::GetCShuffleLDSDescriptor;
using Base::GetVgprToLDSEpilogueDescriptor;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::NumDTensor;
// Describes a Welford mean/variance tensor as a 2-D [M, N] view with strides {N, 1}
// and hands it to PadTensorDescriptor together with the (MPerTile, NPerTile) tile
// lengths and the DoPads flags.
template <typename DoPads, index_t MPerTile, index_t NPerTile>
__host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{
    const auto lengths      = make_tuple(M, N);
    const auto strides      = make_tuple(N, I1);
    const auto tile_lengths = make_tuple(MPerTile, NPerTile);

    return tensor_operation::device::PadTensorDescriptor(
        make_naive_tensor_descriptor(lengths, strides), tile_lengths, DoPads{});
}
// Describes the Welford count tensor as a 2-D [M, N] view. The M stride is 0, so a
// single row of N counts is broadcast over all M rows. The view is then handed to
// PadTensorDescriptor with the (MPerTile, NPerTile) tile lengths and the DoPads flags.
template <typename DoPads, index_t MPerTile, index_t NPerTile>
__host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N)
{
    const auto lengths           = make_tuple(M, N);
    const auto broadcast_strides = make_tuple(I0, I1); // stride 0 along M => broadcast
    const auto tile_lengths      = make_tuple(MPerTile, NPerTile);

    return tensor_operation::device::PadTensorDescriptor(
        make_naive_tensor_descriptor(lengths, broadcast_strides), tile_lengths, DoPads{});
}
// Splits the M dimension of an [M, NBlock] descriptor into (MBlock, MPerBlock) and
// passes the NBlock dimension through, yielding a 3-D [MBlock, MPerBlock, NBlock] view.
// MBlock is computed as M / MPerBlock (integer division).
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
    const auto m_length      = grid_desc_m_n.GetLength(I0);
    const auto nblock_length = grid_desc_m_n.GetLength(I1);
    const auto mblock_length = m_length / MPerBlock;

    return transform_tensor_descriptor(
        grid_desc_m_n,
        make_tuple(make_unmerge_transform(make_tuple(mblock_length, Number<MPerBlock>{})),
                   make_pass_through_transform(nblock_length)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),      // source dims: M, NBlock
        make_tuple(Sequence<0, 1>{}, Sequence<2>{})); // result dims: (MBlock, MPerBlock), NBlock
}
using GemmMeanVarGridDesc_M_N =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
using GemmCountGridDesc_M_N =
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
// Stores the Welford output pointers and the raw N extent, then builds the
// [M, NBlock] mean/variance and count descriptors, where NBlock = ceil(NRaw_ / NPerBlock).
// Both descriptors are created with the <true, false> pad flags and an
// (MPerBlock, 1) tile.
__device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_,
                                   EDataType* p_welford_var_grid_,
                                   int32_t* p_welford_count_grid_,
                                   index_t MRaw_,
                                   index_t NRaw_)
    : p_welford_mean_grid(p_welford_mean_grid_),
      p_welford_var_grid(p_welford_var_grid_),
      p_welford_count_grid(p_welford_count_grid_),
      NRaw(NRaw_)
{
    using PadFlags = Sequence<true, false>;

    // One Welford partial per block along N.
    const index_t num_n_blocks = math::integer_divide_ceil(NRaw_, NPerBlock);

    gemm_mean_var_grid_desc_m_nblock =
        MakeMeanVarDescriptor_M_N<PadFlags, MPerBlock, 1>(MRaw_, num_n_blocks);
    gemm_count_grid_desc_m_nblock =
        MakeCountDescriptor_M_N<PadFlags, MPerBlock, 1>(MRaw_, num_n_blocks);
}
// Fused CShuffle + Welford epilogue for one output tile.
//
// For every shuffle step the accumulators are staged from VGPR into LDS, read back
// together with the D tensors, combined by cde_element_op, written to E in global memory
// and also kept in VGPR (e_thread_buf) so that a threadwise Welford update can consume
// them. After the last step the per-thread partial results are merged with
// BlockwiseWelford, and the threads whose N cluster index is 0 write mean and variance
// to the Welford grids; the count is written by a single thread and only when
// block_m_id == 0.
//
// Parameters:
//   c_thread_buf    - per-thread accumulator buffer (layout given by
//                     BlockwiseGemmPipe's C thread descriptor)
//   p_ds_grid       - indexable collection of pointers to the D tensors (global memory)
//   p_e_grid        - pointer to the E output tensor (global memory)
//   p_shared        - LDS scratch, reinterpreted as CShuffleDataType for the shuffle tile
//   ds_grid_desc_mblock_mperblock_nblock_nperblock - descriptors of the D tensors
//   e_grid_desc_mblock_mperblock_nblock_nperblock  - descriptor of the E tensor
//   cde_element_op  - elementwise functor forwarded to the LDS -> Vmem transfer
//   block_m_id, block_n_id - tile indices used as MBlock / NBlock coordinates
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename CThreadBuf,
typename DsGridPointer,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ void Run(CThreadBuf& c_thread_buf,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
CDEElementwiseOperation& cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id)
{
// Vmem buffers: Ds, E and the three Welford outputs (mean, variance, count)
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// mean and variance share one [MBlock, MPerBlock, NBlock] descriptor
auto mean_var_grid_desc_mblock_mperblock_nblock =
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_mean_var_grid_desc_m_nblock);
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// count uses its own descriptor (built from the stride-0 count descriptor)
auto count_grid_desc_mblock_mperblock_nblock =
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(gemm_count_grid_desc_m_nblock);
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// LDS buffer holding one shuffle tile
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize());
// tuple of reference to C/Ds tensor buffers (mix LDS and Vmem)
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// Thread transfer Vgpr to LDS
auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
// Space Filling Curve Vgpr
constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
// Space Filling Curve Vmem
constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
// C thread descriptor
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
BlockwiseGemmPipe::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// Thread transfer LDS to Vmem; AccDataType is passed as the intermediate data type
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, AccDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
// Block descriptor
constexpr auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();
// E Vgpr buffer: per-thread slice of one shuffle tile
// (shuffle tile extent divided by the transfer cluster length in M / N)
constexpr index_t PostShuffleThreadSliceSize_M =
(CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1);
constexpr index_t PostShuffleThreadSliceSize_N =
(CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3);
constexpr auto PostShuffleThreadSliceSize_M_N =
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
// Welford
constexpr auto post_shuffle_thread_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{},
Number<PostShuffleThreadSliceSize_M>{},
Number<1>{},
Number<PostShuffleThreadSliceSize_N>{}));
// holds the post-elementwise values in AccDataType for the Welford update
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
using PostShuffleThreadClusterSize_M_N = Sequence<
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>;
constexpr auto post_shuffle_thread_cluster_desc =
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
// (m, n) position of this thread inside the post-shuffle thread cluster
const auto post_shuffle_thread_cluster_idx =
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto post_shuffle_thread_data_idx_begin =
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(
Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));
constexpr auto thread_welford_dst_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
decltype(thread_welford_src_desc_m_k),
decltype(thread_welford_dst_desc_m)>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>,
false>;
// number of shuffle tiles needed to cover the block tile along M and N
constexpr int num_shuffleM =
MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);
constexpr int num_shuffleN =
NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);
using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
// one Welford state (and one set of result buffers) per shuffle tile along M
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
// default: every N element this thread visits across all N shuffle steps
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block: the last block along N may extend past NRaw, so the per-thread count is
// recomputed from the remaining length NPerBlockTail
if(block_n_id % nblock == nblock - 1)
{
constexpr index_t NPerShuffleBlock =
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;
int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
// exclusive end (in N, relative to the block) of this thread's slice in the first
// shuffle step
int thread_max_len =
PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
int shuffle_step = 0;
// count the shuffle steps whose slice for this thread ends within the tail length
while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
{
++shuffle_step;
thread_max_len += NPerShuffleBlock;
}
// delta: elements of the first slice not counted by the loop that still lie
// (fully or partially) inside the tail length
int delta = 0;
if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
delta = 0;
else if(NPerBlockTail > thread_max_len)
delta = PostShuffleThreadSliceSize_N;
else
delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
}
// Initialize Welford: set max_count_ and zero the mean / variance / count buffers
static_for<0, num_shuffleM, 1>{}([&](auto i) {
threadwise_welfords(i).max_count_ = max_count;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// the Vgpr and Vmem space filling curves must have the same number of steps
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
// Run CShuffle + Store E + Welford threadwise
// shuffleM_index selects which per-M-tile Welford state the current step updates
// NOTE(review): readfirstlane presumably keeps the index wave-uniform - confirm
int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread shuffle data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// Read LDS / Vmem + CDE elementwise operation
cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);
// Store to Vmem, but keep data in Vgpr for Welford
cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf),
tie(post_shuffle_thread_desc_m_n),
tie(e_thread_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
// Threadwise welford on the values just written to E
auto& threadwise_welford = threadwise_welfords(shuffleM_index);
auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
auto& var_thread_buf = var_thread_bufs(shuffleM_index);
threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
if constexpr(access_id < num_access - 1)
{
// advance the M tile index by the M component of the next global step,
// expressed in units of one shuffle tile height
constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
constexpr int shuffleMInc =
de_global_step[I1] /
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
I1);
shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
}
});
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
// every row starts from this thread's own sample count
count_thread_buf(j) = threadwise_welfords(i).cur_count_;
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
// only the threads with N cluster index 0 write the results
if(post_shuffle_thread_cluster_idx[I1] == 0)
{
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
I1);
auto mean_var_count_thread_copy_index = make_multi_index(
block_m_id, // mblock
shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
block_n_id); // nblock
auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>{mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf); // write mean
mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); // write variance
// Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
// to be written.
if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
{
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{count_grid_desc_mblock_mperblock_nblock,
mean_var_count_thread_copy_index,
tensor_operation::element_wise::PassThrough{}};
count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf); // write count
}
}
});
}
EDataType* p_welford_mean_grid;
EDataType* p_welford_var_grid;
int32_t* p_welford_count_grid;
index_t NRaw;
GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock;
GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock;
};
} // namespace ck

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
namespace ck {
template <typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
typename CDEElementwiseOperation,
typename ThisThreadBlock,
typename BlockwiseGemmPipe>
struct EpilogueCShuffle
: EpilogueCShuffleBase<DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>
{
using Base = EpilogueCShuffleBase<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::GetCShuffleLDSDescriptor;
using Base::GetVgprToLDSEpilogueDescriptor;
using Base::I1;
using Base::NumDTensor;
// Plain CShuffle epilogue for one output tile.
//
// For every step of the space filling curve the accumulators are staged from VGPR into
// LDS, then the block-level transfer reads C from LDS and the D tensors from global
// memory, applies cde_element_op and stores the result to E in global memory.
//
// Parameters:
//   c_thread_buf    - per-thread accumulator buffer (layout given by
//                     BlockwiseGemmPipe's C thread descriptor)
//   p_ds_grid       - indexable collection of pointers to the D tensors (global memory)
//   p_e_grid        - pointer to the E output tensor (global memory)
//   p_shared        - LDS scratch, reinterpreted as CShuffleDataType for the shuffle tile
//   ds_grid_desc_mblock_mperblock_nblock_nperblock - descriptors of the D tensors
//   e_grid_desc_mblock_mperblock_nblock_nperblock  - descriptor of the E tensor
//   cde_element_op  - elementwise functor forwarded to the LDS -> Vmem transfer
//   block_m_id, block_n_id - tile indices used as MBlock / NBlock coordinates
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename CThreadBuf,
typename DsGridPointer,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ static void Run(CThreadBuf& c_thread_buf,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
CDEElementwiseOperation& cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id)
{
// Vmem buffers for the D tensors and for E
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
BlockwiseGemmPipe::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// LDS buffer holding one shuffle tile
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize());
// Thread transfer Vgpr to LDS
auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
// Space Filling Curve Vgpr
constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
// Space Filling Curve Vmem
constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
// Block descriptor
constexpr auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// Thread transfer LDS to Vmem; EDataType is passed as the intermediate data type
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// the Vgpr and Vmem space filling curves must have the same number of steps
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
// CShuffle and Store
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
// advance the source / destination windows, except after the last step
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds (source index 0 is the LDS tile, hence i + I1)
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}
};
} // namespace ck

View File

@@ -0,0 +1,253 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
namespace ck {
template <typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
typename CDEElementwiseOperation,
typename ThisThreadBlock,
typename BlockwiseGemmPipe>
struct EpilogueCShuffleBase
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
using SpaceFillingCurveVgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
1,
1,
CShuffleNRepeatPerShuffle,
1,
1,
BlockwiseGemmPipe::MAccVgprs>>;
using SpaceFillingCurveVmem = SpaceFillingCurve<
Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
1,
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
// Packed 4-D descriptor [1, MPerShuffle, 1, NPerShuffle] of the tile that is staged per
// shuffle step. Caution: "repeat" in the function name means the shuffle repeat.
__device__ static constexpr auto
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
    // Wave counts derived from the block tile and the per-wave repeat configuration.
    constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
    constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

    // Extent of one shuffle tile along M and N.
    constexpr index_t MPerShuffle = CShuffleMRepeatPerShuffle * MWaves * MPerWmma;
    constexpr index_t NPerShuffle = CShuffleNRepeatPerShuffle * NWaves * NPerWmma;

    return make_naive_tensor_descriptor_packed(
        make_tuple(I1, Number<MPerShuffle>{}, I1, Number<NPerShuffle>{}));
}
// Re-expresses the packed shuffle tile in the 7-D
// [MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs] index space:
// the two unit dimensions are frozen at 0 and the M / N extents are unmerged using the
// lengths reported by the blockwise GEMM's C block descriptor.
__device__ static constexpr auto GetCShuffleLDSDescriptor()
{
    // C mapping in a single block; only its lengths are consumed here.
    constexpr auto c_block_desc_7d = BlockwiseGemmPipe::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

    constexpr auto MWave              = c_block_desc_7d.GetLength(I1);
    constexpr auto MSubGroup          = c_block_desc_7d.GetLength(I2);
    constexpr auto NWave              = c_block_desc_7d.GetLength(I4);
    constexpr auto NThreadPerSubGroup = c_block_desc_7d.GetLength(I5);
    constexpr auto MAccVgprs          = c_block_desc_7d.GetLength(I6);

    return transform_tensor_descriptor(
        GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(),
        make_tuple(
            // unit M dimension of the shuffle tile
            make_freeze_transform(I0),
            // M extent -> (MRepeat per shuffle, MWave, MSubGroup, MAccVgprs);
            // MSubGroup * MAccVgprs = MPerWmma
            make_unmerge_transform(make_tuple(
                Number<CShuffleMRepeatPerShuffle>{}, MWave, MSubGroup, MAccVgprs)),
            // unit N dimension of the shuffle tile
            make_freeze_transform(I0),
            // N extent -> (NRepeat per shuffle, NWave, NThreadPerSubGroup);
            // NThreadPerSubGroup = NPerWmma
            make_unmerge_transform(make_tuple(
                Number<CShuffleNRepeatPerShuffle>{}, NWave, NThreadPerSubGroup))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
}
// Builds the per-thread transfer that copies one shuffle slice of the accumulators
// (AccDataType) into the shuffle tile described by GetCShuffleLDSDescriptor()
// (CShuffleDataType). The destination origin is derived from this thread's C origin
// inside the block as reported by BlockwiseGemmPipe.
__device__ static auto GetVgprToLDSEpilogueDescriptor()
{
    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

    // C mapping in a single block; only its lengths are consumed here.
    constexpr auto c_block_desc_7d = BlockwiseGemmPipe::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

    constexpr auto MWave              = c_block_desc_7d.GetLength(I1);
    constexpr auto MSubGroup          = c_block_desc_7d.GetLength(I2);
    constexpr auto NWave              = c_block_desc_7d.GetLength(I4);
    constexpr auto NThreadPerSubGroup = c_block_desc_7d.GetLength(I5);
    constexpr auto MAccVgprs          = c_block_desc_7d.GetLength(I6);

    // (m, n) origin of this thread's C data inside the block tile.
    const auto thread_origin_m_n = BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);

    const index_t thread_origin_m = thread_origin_m_n[I0];
    const index_t thread_origin_n = thread_origin_m_n[I1];

    // Decompose the flat m origin into (MRepeat, MWave, MSubGroup, MAccVgprs).
    const auto m_to_mrepeat_mwave_msubgroup_maccvgprs = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
        make_tuple(Sequence<0, 1, 2, 3>{}),
        make_tuple(Sequence<0>{}));

    // Decompose the flat n origin into (NRepeat, NWave, NThreadPerSubGroup).
    const auto n_to_nrepeat_nwave_nthreadpersubgroup = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const auto m_idx = m_to_mrepeat_mwave_msubgroup_maccvgprs.CalculateBottomIndex(
        make_multi_index(thread_origin_m));
    const auto n_idx = n_to_nrepeat_nwave_nthreadpersubgroup.CalculateBottomIndex(
        make_multi_index(thread_origin_n));

    using CThreadDesc = decltype(
        BlockwiseGemmPipe::
            GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs());
    using CShuffleLdsDesc = decltype(GetCShuffleLDSDescriptor());

    // Slice copied per Run(): one shuffle repeat in M and N, all accumulator VGPRs.
    using SliceLengths = Sequence<CShuffleMRepeatPerShuffle,
                                  I1,
                                  I1,
                                  CShuffleNRepeatPerShuffle,
                                  I1,
                                  I1,
                                  MAccVgprs>;

    using VgprToLdsTransfer = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                                 CShuffleDataType,
                                                                 CThreadDesc,
                                                                 CShuffleLdsDesc,
                                                                 PassThroughOp,
                                                                 SliceLengths,
                                                                 Sequence<0, 1, 2, 3, 4, 5, 6>,
                                                                 6,
                                                                 1,
                                                                 InMemoryDataOperationEnum::Set,
                                                                 1,
                                                                 true>;

    // The repeat coordinates start at 0; the wave / sub-group / vgpr coordinates come
    // from the decomposed thread origin.
    return VgprToLdsTransfer{GetCShuffleLDSDescriptor(),
                             make_multi_index(0,
                                              m_idx[I1],
                                              m_idx[I2],
                                              0,
                                              n_idx[I1],
                                              n_idx[I2],
                                              m_idx[I3]),
                             PassThroughOp{}};
}
// Builds the block-level transfer that loads C from the shuffle tile and the D tensors
// from global memory, applies cde_element_op and stores the result to E.
//
// Template parameters:
//   EGlobalMemoryDataOperation - memory operation used for the E store
//   InterDataType              - intermediate data type forwarded to the v7r3 transfer
// Parameters:
//   c_ds_desc_refs - tuple of references to the C (shuffle tile) and D descriptors
//   e_grid_desc_mblock_mperblock_nblock_nperblock - descriptor of the E tensor
//   cde_element_op - elementwise functor stored in the transfer
//   block_m_id, block_n_id - MBlock / NBlock origin of the D sources and the E destination
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          typename InterDataType,
          typename CDsDescRefs,
          typename EGridDesc>
__device__ static auto
GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
                               EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
{
    // Source element types: the shuffle tile first, followed by every D tensor.
    using SrcDataTypes =
        decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{}));
    using DstDescRefs = decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock));

    // One shuffle tile per Run().
    using BlockSliceLengths =
        Sequence<1,
                 CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
                 1,
                 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>;

    // Source reset-after-run flags: true for the shuffle tile, false for each D tensor.
    using SrcResetCoordinateAfterRunFlags =
        sequence_merge_t<Sequence<true>, uniform_sequence_gen_t<NumDTensor, false>>;

    using LdsToVmemTransfer = ThreadGroupTensorSliceTransfer_v7r3<
        ThisThreadBlock,                                            // ThreadGroup
        SrcDataTypes,                                               // SrcDatas
        Tuple<EDataType>,                                           // DstDatas
        CDsDescRefs,                                                // SrcDescs
        DstDescRefs,                                                // DstDescs
        CDEElementwiseOperation,                                    // ElementwiseOperation
        Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps
        BlockSliceLengths,                                          // BlockSliceLengths
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        Sequence<0, 1, 2, 3>,                    // ThreadClusterArrangeOrder
        Sequence<0, 1, 2, 3>,                    // SrcDimAccessOrder
        Sequence<0, 1, 2, 3>,                    // DstDimAccessOrder
        3,                                       // SrcVectorDim
        3,                                       // DstVectorDim
        CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
        EShuffleBlockTransferScalarPerVector,    // DstScalarPerVector
        SrcResetCoordinateAfterRunFlags,         // ThreadTransferSrcResetCoordinateAfterRunFlags
        Sequence<false>,                         // ThreadTransferDstResetCoordinateAfterRunFlags
        1,
        Tuple<InterDataType>>;

    // Source origins: the shuffle tile starts at 0; every D starts at this block's tile.
    const auto src_block_origins = container_concat(
        make_tuple(make_multi_index(0, 0, 0, 0)),
        generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
                       Number<NumDTensor>{}));

    return LdsToVmemTransfer{c_ds_desc_refs,
                             src_block_origins,
                             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                             make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                             cde_element_op};
}
};
} // namespace ck

View File

@@ -315,8 +315,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -556,7 +554,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
@@ -565,7 +564,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -610,6 +610,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid,
@@ -627,16 +628,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
b_scale_struct);
b_scale_struct,
epilogue_args);
}
// Wrapper that lets gemm_universal, b_scale, ab_scale, etc. share a single
// __global__ kernel entry point.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -663,7 +668,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
karg.cde_element_op,
epilogue_args);
}
};

View File

@@ -209,8 +209,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -533,7 +531,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
@@ -543,7 +542,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
@@ -593,6 +593,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid,
@@ -610,16 +611,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
block_m_id,
block_n_id,
num_k_block_per_scale,
b_scale_struct);
b_scale_struct,
epilogue_args);
}
// NOTE: Wrapper that lets gemm_universal, b_scale, ab_scale, etc. share a single
// __global__ kernel entry point.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -647,7 +652,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
karg.cde_element_op,
epilogue_args);
}
};

View File

@@ -23,6 +23,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
namespace ck {
@@ -46,12 +48,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg);
p_shared, splitk_batch_offset, karg, epilogue_args);
#if defined(__gfx11__)
}
@@ -262,9 +268,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static_assert(!PermuteA, "PermuteA is not supported");
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
@@ -539,23 +543,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumDTensor>{});
}
// Returns the packed 4-D descriptor of the C-shuffle tile, with lengths
// (1, MPerShuffle, 1, NPerShuffle).
// Caution: "repeat" in the name refers to the shuffle repeat.
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
    // Wave counts along M and N, derived from the block tile and the per-wave tile.
    constexpr index_t waves_along_m = MPerBlock / (MRepeat * MPerWmma);
    constexpr index_t waves_along_n = NPerBlock / (NRepeat * NPerWmma);

    // Extent handled by one shuffle step in each direction.
    constexpr index_t m_per_shuffle = CShuffleMRepeatPerShuffle * waves_along_m * MPerWmma;
    constexpr index_t n_per_shuffle = CShuffleNRepeatPerShuffle * waves_along_n * NPerWmma;

    return make_naive_tensor_descriptor_packed(
        make_tuple(I1, Number<m_per_shuffle>{}, I1, Number<n_per_shuffle>{}));
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
BlkGemmPipeSched,
@@ -578,6 +565,46 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
NRepeat,
KPack>())>;
// Epilogue type instantiated with this gridwise GEMM's data types, tile sizes and
// C-shuffle transfer parameters. The __global__ kernel creates an object of this
// type and passes it to the Run method as the epilogue argument.
using EpilogueCShuffle =
EpilogueCShuffle<DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>;
// Welford variant of the epilogue type, instantiated with the same template
// arguments as the EpilogueCShuffle alias plus BlockSize as the last argument.
// NOTE(review): presumably selected by kernels that fuse a Welford (mean/variance)
// step into the epilogue - confirm against the callers.
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe,
BlockSize>;
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
@@ -821,6 +848,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
template <typename EpilogueType>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
@@ -838,7 +866,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
EpilogueType::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto c_block_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
@@ -867,6 +896,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename BScaleStruct,
typename EpilogueArgument,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
@@ -887,7 +917,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
BScaleStruct& b_scale_struct)
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args)
{
const auto as_grid_buf = generate_tuple(
[&](auto i) {
@@ -903,16 +934,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
},
Number<NumBTensor>{});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
@@ -984,240 +1005,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
num_k_block_per_scale);
// shuffle C and write out
{
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm_pipeline
.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// C mapping in single block
constexpr auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm_pipeline
.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I1);
constexpr auto MSubGroup =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I2);
constexpr auto NWave =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I4);
constexpr auto NThreadPerSubGroup =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I5);
constexpr auto MAccVgprs =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I6);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize());
constexpr auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
MWave, // MWave
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
MAccVgprs)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
NWave, // NWave
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{},
Sequence<0, 1, 2, 6>{},
Sequence<>{},
Sequence<3, 4, 5>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
.CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
NRepeat, NWave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
.CalculateBottomIndex(make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
I1,
I1,
CShuffleNRepeatPerShuffle,
I1,
I1,
MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1, // vector write pixel
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
make_multi_index(0,
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
0,
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
Number<NumDTensor>{}));
// blockwise copy which loads C from LDS, D from global, applies elementwise
// operation and stores result E to global
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
cde_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
1,
1,
CShuffleNRepeatPerShuffle,
1,
1,
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_cde_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}
epilogue_args.template Run<EGlobalMemoryDataOperation>(
c_thread_buf,
p_ds_grid,
p_e_grid,
p_shared,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
}
};