mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor epilogue (with CShuffle) to support fused operations:
- EpilogueCShuffleBase holds common parts
- EpilogueCShuffle: runs CShuffle and write out
- EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out
- Extend thread transfer v7r3:
- Support for intermediate data type different from src and dst type
- New functionality to write to dst buffer and keep data (to be able to use them for additional operations)
* Address review comments
This commit is contained in:
@@ -0,0 +1,510 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GEMM epilogue that fuses the CShuffle write-out of E with the first (per-block)
// stage of a Welford mean/variance reduction along N (e.g. for a fused layernorm):
//  - shuffles the accumulator tile from VGPR through LDS and stores E to global,
//    keeping the freshly written values in VGPR (RunWriteAndStoreVgpr),
//  - runs a threadwise Welford update on those values after every shuffle step,
//  - reduces the per-thread Welford states blockwise and writes partial
//    mean / variance / count per (m, nblock) to global for a later reduction.
template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe,
          index_t BlockSize>
struct EpilogueWelfordCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe>;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
    using Base::GetCShuffleLDSDescriptor;
    using Base::GetVgprToLDSEpilogueDescriptor;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::NumDTensor;

    // Row-major [M, N] descriptor (strides [N, 1]), padded up to multiples of
    // (MPerTile, NPerTile) according to DoPads. Used for the mean/var outputs.
    template <typename DoPads, index_t MPerTile, index_t NPerTile>
    __host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
    {
        const auto grid_desc_m_n =
            make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1));
        return tensor_operation::device::PadTensorDescriptor(
            grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
    }

    // [M, N] view of the 1-D count buffer, padded like MakeMeanVarDescriptor_M_N.
    template <typename DoPads, index_t MPerTile, index_t NPerTile>
    __host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N)
    {
        // We will broadcast [N] to [M, N] in this descriptor
        // Hence, 1st stride is 0
        const auto grid_desc_m_n =
            make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
        return tensor_operation::device::PadTensorDescriptor(
            grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
    }

    // Splits the M dimension of an [M, NBlock] descriptor into (MBlock, MPerBlock),
    // yielding the [MBlock, MPerBlock, NBlock] layout used by the write-out copies.
    template <typename GridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
    {
        const auto M = grid_desc_m_n.GetLength(I0);
        const auto NBlock = grid_desc_m_n.GetLength(I1);
        const auto MBlock = M / MPerBlock;

        const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
            grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_pass_through_transform(NBlock)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        return grid_desc_mblock_mperblock_nblock;
    }

    using GemmMeanVarGridDesc_M_N =
        decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));

    using GemmCountGridDesc_M_N =
        decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));

    // Stores the Welford output pointers and builds the [M, NBlock] descriptors
    // for the per-block partial results (one column per N-block of the GEMM).
    // MRaw_/NRaw_ are the unpadded GEMM extents.
    __device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_,
                                       EDataType* p_welford_var_grid_,
                                       int32_t* p_welford_count_grid_,
                                       index_t MRaw_,
                                       index_t NRaw_)
        : p_welford_mean_grid(p_welford_mean_grid_),
          p_welford_var_grid(p_welford_var_grid_),
          p_welford_count_grid(p_welford_count_grid_),
          NRaw(NRaw_)
    {
        index_t gemm_nblock = math::integer_divide_ceil(NRaw_, NPerBlock);

        gemm_mean_var_grid_desc_m_nblock =
            MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);

        gemm_count_grid_desc_m_nblock =
            MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
    }

    // Runs the fused epilogue for one (block_m_id, block_n_id) tile:
    // CShuffle VGPR->LDS->global store of E (with CDE elementwise op and D inputs),
    // threadwise Welford over the stored E values, then blockwise Welford and
    // write-out of the partial mean/var/count. p_shared is the block's LDS scratch.
    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ void Run(CThreadBuf& c_thread_buf,
                        DsGridPointer p_ds_grid,
                        EDataType* p_e_grid,
                        void* p_shared,
                        const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            ds_grid_desc_mblock_mperblock_nblock_nperblock,
                        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                            e_grid_desc_mblock_mperblock_nblock_nperblock,
                        CDEElementwiseOperation& cde_element_op,
                        const index_t& block_m_id,
                        const index_t& block_n_id)
    {
        // Vmem buffers
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        auto mean_var_grid_desc_mblock_mperblock_nblock =
            MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
                gemm_mean_var_grid_desc_m_nblock);

        auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        auto count_grid_desc_mblock_mperblock_nblock =
            MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(gemm_count_grid_desc_m_nblock);
        auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());

        // tuple of reference to C/Ds tensor buffers (mix LDS and Vmem)
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer Vgpr to LDS
        auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();

        // Space Filling Curve Vgpr
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // Space Filling Curve Vmem
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // C thread descriptor
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer LDS to Vmem. Note: the intermediate type is AccDataType
        // (not EDataType) so the values kept in VGPR feed Welford at full precision.
        auto cde_shuffle_block_copy_lds_and_global =
            Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, AccDataType>(
                c_ds_desc_refs,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                cde_element_op,
                block_m_id,
                block_n_id);

        // Block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                GetCShuffleLDSDescriptor();

        // E Vgpr buffer: per-thread slice of one shuffle tile
        constexpr index_t PostShuffleThreadSliceSize_M =
            (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1);

        constexpr index_t PostShuffleThreadSliceSize_N =
            (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3);

        constexpr auto PostShuffleThreadSliceSize_M_N =
            Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};

        // Welford
        constexpr auto post_shuffle_thread_desc_m_n =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{},
                                                           Number<PostShuffleThreadSliceSize_M>{},
                                                           Number<1>{},
                                                           Number<PostShuffleThreadSliceSize_N>{}));

        auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            post_shuffle_thread_desc_m_n.GetElementSpaceSize());

        using PostShuffleThreadClusterSize_M_N = Sequence<
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
            CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>;

        constexpr auto post_shuffle_thread_cluster_desc =
            make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});

        const auto post_shuffle_thread_cluster_idx =
            post_shuffle_thread_cluster_desc.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

        const auto post_shuffle_thread_data_idx_begin =
            post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;

        constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(
            Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));

        constexpr auto thread_welford_dst_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<PostShuffleThreadSliceSize_M>{}));

        using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
                                                    decltype(thread_welford_src_desc_m_k),
                                                    decltype(thread_welford_dst_desc_m)>;

        using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                                  BlockSize,
                                                  PostShuffleThreadClusterSize_M_N,
                                                  Sequence<0, 1>,
                                                  false>;

        // number of shuffle iterations along M / N needed to cover the block tile
        constexpr int num_shuffleM =
            MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);

        constexpr int num_shuffleN =
            NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);

        using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            thread_welford_dst_desc_m.GetElementSpaceSize()));

        using welford_count_vgpr_type =
            decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                thread_welford_dst_desc_m.GetElementSpaceSize()));

        // One Welford state / mean / var / count buffer per M-shuffle iteration,
        // since different iterations cover different M rows.
        Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
        Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
        Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
        Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;

        int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
        const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);

        // tail block: clamp this thread's Welford max_count to the number of valid
        // (unpadded, < NRaw) N elements it will actually visit in the last N-block
        if(block_n_id % nblock == nblock - 1)
        {
            constexpr index_t NPerShuffleBlock =
                CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;

            int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
            // exclusive upper bound of this thread's N footprint after shuffle_step steps
            int thread_max_len =
                PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
            int shuffle_step = 0;
            while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
            {
                ++shuffle_step;
                thread_max_len += NPerShuffleBlock;
            }

            // delta = number of valid elements in the partially covered slice (if any)
            int delta = 0;
            if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
                delta = 0;
            else if(NPerBlockTail > thread_max_len)
                delta = PostShuffleThreadSliceSize_N;
            else
                delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;

            max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
        }

        // Initialize Welford
        static_for<0, num_shuffleM, 1>{}([&](auto i) {
            threadwise_welfords(i).max_count_ = max_count;
            mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                thread_welford_dst_desc_m.GetElementSpaceSize());

            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
                var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
                welford_count_thread_bufs(i)(j) = 0;
            });
        });

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

        // Run CShuffle + Store E + Welford threadwise
        int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread shuffle data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // Read LDS / Vmem + CDE elementwise operation
            cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);

            // Store to Vmem, but keep data in Vgpr for Welford
            cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf),
                tie(post_shuffle_thread_desc_m_n),
                tie(e_thread_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_global_step);
                });

                // move on E
                cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
            }

            // Threadwise welford: accumulate this shuffle tile's E values into the
            // Welford state of the M-shuffle row currently being processed
            auto& threadwise_welford = threadwise_welfords(shuffleM_index);
            auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
            auto& var_thread_buf = var_thread_bufs(shuffleM_index);

            threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);

            if constexpr(access_id < num_access - 1)
            {
                // advance shuffleM_index when the space-filling curve steps in M
                constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
                constexpr int shuffleMInc =
                    de_global_step[I1] /
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I1);
                shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
            }
        });

        // Blockwise welford and write out
        static_for<0, num_shuffleM, 1>{}([&](auto i) {
            auto& mean_thread_buf = mean_thread_bufs(i);
            auto& var_thread_buf = var_thread_bufs(i);
            auto& count_thread_buf = welford_count_thread_bufs(i);

            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                block_sync_lds();
                count_thread_buf(j) = threadwise_welfords(i).cur_count_;
                BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
            });

            // After the blockwise reduction only the first thread column holds the
            // final per-row result, so only those threads write out
            if(post_shuffle_thread_cluster_idx[I1] == 0)
            {
                constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
                    make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));

                constexpr int shuffleMPerBlock =
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
                        I1);

                auto mean_var_count_thread_copy_index = make_multi_index(
                    block_m_id,                                                    // mblock
                    shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
                    block_n_id);                                                   // nblock

                auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                    AccDataType,
                    EDataType,
                    decltype(thread_welford_desc_I_m_I),
                    decltype(mean_var_grid_desc_mblock_mperblock_nblock),
                    tensor_operation::element_wise::PassThrough,
                    Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                    Sequence<0, 1, 2>,
                    1,
                    1,
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>{mean_var_grid_desc_mblock_mperblock_nblock,
                          mean_var_count_thread_copy_index,
                          tensor_operation::element_wise::PassThrough{}};

                mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                        make_tuple(I0, I0, I0),
                                                        mean_thread_buf,
                                                        mean_var_grid_desc_mblock_mperblock_nblock,
                                                        mean_grid_buf); // write mean

                mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                        make_tuple(I0, I0, I0),
                                                        var_thread_buf,
                                                        mean_var_grid_desc_mblock_mperblock_nblock,
                                                        var_grid_buf); // write variance

                // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
                // to be written.
                if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
                {
                    auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                        int32_t,
                        int32_t,
                        decltype(thread_welford_desc_I_m_I),
                        decltype(count_grid_desc_mblock_mperblock_nblock),
                        tensor_operation::element_wise::PassThrough,
                        Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                        Sequence<0, 1, 2>,
                        1,
                        1,
                        InMemoryDataOperationEnum::Set,
                        1,
                        false>{count_grid_desc_mblock_mperblock_nblock,
                               mean_var_count_thread_copy_index,
                               tensor_operation::element_wise::PassThrough{}};

                    count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                         make_tuple(I0, I0, I0),
                                                         count_thread_buf,
                                                         count_grid_desc_mblock_mperblock_nblock,
                                                         welford_count_grid_buf); // write count
                }
            }
        });
    }

    // Global output pointers for the per-block Welford partial results.
    EDataType* p_welford_mean_grid;
    EDataType* p_welford_var_grid;
    int32_t* p_welford_count_grid;
    // Unpadded N extent of the GEMM; used to size the tail block's Welford count.
    index_t NRaw;
    // [M, NBlock] descriptors (M padded to MPerBlock) of the partial outputs.
    GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock;
    GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock;
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,195 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Plain CShuffle epilogue (no fused reduction): shuffles the accumulator tile
// from VGPR through LDS, applies the CDE elementwise operation together with
// the D tensors, and stores the result E to global memory.
template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          typename CDEElementwiseOperation,
          typename ThisThreadBlock,
          typename BlockwiseGemmPipe>
struct EpilogueCShuffle
    : EpilogueCShuffleBase<DsDataType,
                           EDataType,
                           AccDataType,
                           CShuffleDataType,
                           MPerBlock,
                           NPerBlock,
                           MPerWmma,
                           NPerWmma,
                           MRepeat,
                           NRepeat,
                           CShuffleMRepeatPerShuffle,
                           CShuffleNRepeatPerShuffle,
                           CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                           CDEShuffleBlockTransferScalarPerVectors,
                           CDEElementwiseOperation,
                           ThisThreadBlock,
                           BlockwiseGemmPipe>
{
    using Base = EpilogueCShuffleBase<
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        MPerBlock,
        NPerBlock,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        CDEElementwiseOperation,
        ThisThreadBlock,
        BlockwiseGemmPipe>;

    using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
    using Base::GetCShuffleLDSDescriptor;
    using Base::GetVgprToLDSEpilogueDescriptor;
    using Base::I1;
    using Base::NumDTensor;

    // Runs the epilogue for one (block_m_id, block_n_id) tile: iterates the
    // space-filling curve over the block tile, each step doing a VGPR->LDS
    // shuffle followed by an LDS/global read + elementwise op + global store of E.
    // p_shared is the block's LDS scratch.
    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ static void Run(CThreadBuf& c_thread_buf,
                               DsGridPointer p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
    {
        // Vmem buffers for the D inputs and the E output
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping in single thread.
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // LDS buffer
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize());

        // Thread transfer Vgpr to LDS
        auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();

        // Space Filling Curve Vgpr
        constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};

        // Space Filling Curve Vmem
        constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};

        // Block descriptor
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                GetCShuffleLDSDescriptor();

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // Thread transfer LDS to Vmem (intermediate type EDataType: no fused op
        // needs the values after the store)
        auto cde_shuffle_block_copy_lds_and_global =
            Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
                c_ds_desc_refs,
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                cde_element_op,
                block_m_id,
                block_n_id);

        // tuple of reference to C/Ds tensor buffers
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

        // CShuffle and Store
        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread write its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                c_thread_buf,
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // each block loads its C data from LDS, D from global, applies elementwise
            // operation and stores result E to global
            cde_shuffle_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(e_grid_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                // move on Ds
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_global_step);
                });

                // move on E
                cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
            }
        });
    }
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ThisThreadBlock,
|
||||
typename BlockwiseGemmPipe>
|
||||
struct EpilogueCShuffleBase
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
using SpaceFillingCurveVgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
BlockwiseGemmPipe::MAccVgprs>>;
|
||||
|
||||
using SpaceFillingCurveVmem = SpaceFillingCurve<
|
||||
Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
|
||||
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
// Packed descriptor of the LDS tile used for one shuffle iteration:
// [1, CShuffleMRepeatPerShuffle * MWaves * MPerWmma,
//  1, CShuffleNRepeatPerShuffle * NWaves * NPerWmma].
// Note: "Repeat" here means shuffle repeat, not GEMM repeat.
__device__ static constexpr auto
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
    constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
    constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

    constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
        make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                       I1,
                       Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

    return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
|
||||
|
||||
// Re-views the shuffle LDS tile in the blockwise-GEMM thread layout
// (MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs)
// so each thread can write its accumulator slice to the right LDS location.
// The unit-size MShRepeat / NShRepeat dimensions are frozen out; the per-block
// M and N extents are unmerged into the wave/subgroup/vgpr factors queried
// from BlockwiseGemmPipe's C block descriptor.
__device__ static constexpr auto GetCShuffleLDSDescriptor()
{
    // C mapping in single block
    constexpr auto
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
            BlockwiseGemmPipe::
                GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

    constexpr auto MWave =
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
            .GetLength(I1);
    constexpr auto MSubGroup =
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
            .GetLength(I2);
    constexpr auto NWave =
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
            .GetLength(I4);
    constexpr auto NThreadPerSubGroup =
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
            .GetLength(I5);
    constexpr auto MAccVgprs =
        c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
            .GetLength(I6);

    return transform_tensor_descriptor(
        GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(),
        make_tuple(make_freeze_transform(I0),
                   make_unmerge_transform(make_tuple(
                       Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                       MWave,                               // MWave
                       MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
                       MAccVgprs)),
                   make_freeze_transform(I0),
                   make_unmerge_transform(make_tuple(
                       Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                       NWave,                               // NWave
                       NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
}
|
||||
|
||||
__device__ static auto GetVgprToLDSEpilogueDescriptor()
|
||||
{
|
||||
// C mapping in single block
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
BlockwiseGemmPipe::
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I1);
|
||||
constexpr auto MSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I2);
|
||||
constexpr auto NWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I5);
|
||||
constexpr auto MAccVgprs =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I6);
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(BlockwiseGemmPipe::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()),
|
||||
decltype(GetCShuffleLDSDescriptor()),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{GetCShuffleLDSDescriptor(),
|
||||
make_multi_index(0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
0,
|
||||
n_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
}
|
||||
|
||||
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename InterDataType,
|
||||
typename CDsDescRefs,
|
||||
typename EGridDesc>
|
||||
__device__ static auto
|
||||
GetLDSToVmemEpilogueDescriptor(CDsDescRefs& c_ds_desc_refs,
|
||||
EGridDesc& e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
CDEElementwiseOperation& cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id)
|
||||
{
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
return ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
CDsDescRefs,
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves *
|
||||
NPerWmma>, // BlockSliceLengths,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
1,
|
||||
Tuple<InterDataType>>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -315,8 +315,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
@@ -556,7 +554,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
@@ -565,7 +564,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -610,6 +610,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
@@ -627,16 +628,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
b_scale_struct);
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -663,7 +668,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -209,8 +209,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
@@ -533,7 +531,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
@@ -543,7 +542,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -593,6 +593,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
@@ -610,16 +611,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
b_scale_struct);
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -647,7 +652,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -46,12 +48,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg);
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
@@ -262,9 +268,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static_assert(!PermuteA, "PermuteA is not supported");
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
// if arch = gfx942
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
@@ -539,23 +543,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
using BlockwiseGemmPipe =
|
||||
remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
@@ -578,6 +565,46 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
|
||||
// Used to create obj in global function and pass it to Run method
|
||||
using EpilogueCShuffle =
|
||||
EpilogueCShuffle<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>;
|
||||
|
||||
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe,
|
||||
BlockSize>;
|
||||
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
@@ -821,6 +848,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
|
||||
}
|
||||
|
||||
template <typename EpilogueType>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -838,7 +866,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
EpilogueType::
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
@@ -867,6 +896,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename BScaleStruct,
|
||||
typename EpilogueArgument,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -887,7 +917,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
BScaleStruct& b_scale_struct)
|
||||
BScaleStruct& b_scale_struct,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -903,16 +934,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
|
||||
@@ -984,240 +1005,16 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
num_k_block_per_scale);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
blockwise_gemm_pipeline
|
||||
.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// C mapping in single block
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
blockwise_gemm_pipeline
|
||||
.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I1);
|
||||
constexpr auto MSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I2);
|
||||
constexpr auto NWave =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I5);
|
||||
constexpr auto MAccVgprs =
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
|
||||
.GetLength(I6);
|
||||
|
||||
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
.GetElementSpaceSize());
|
||||
|
||||
constexpr auto
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
|
||||
MWave, // MWave
|
||||
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
|
||||
MAccVgprs)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
|
||||
NWave, // NWave
|
||||
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 1, 2, 6>{},
|
||||
Sequence<>{},
|
||||
Sequence<3, 4, 5>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
|
||||
MRepeat, MWave, MSubGroup, MAccVgprs))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(
|
||||
NRepeat, NWave, NThreadPerSubGroup))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
|
||||
.CalculateBottomIndex(make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
1, // vector write pixel
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
make_multi_index(0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
0,
|
||||
n_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
epilogue_args.template Run<EGlobalMemoryDataOperation>(
|
||||
c_thread_buf,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user