Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)

- Add support for direct store in epilogue instead of cshuffle
- Add padding support for wave transfer without transpose
- Add wave transfer with interleaved layout to support direct store
- Enable new functionalities on GEMMs
- Add optional new-functionality support for grouped convolution fwd
- Add some fast instances for grouped convolution fwd with new functionalities (proper tuning needed)
This commit is contained in:
Enrico Degregori
2026-01-14 11:02:19 +01:00
committed by GitHub
parent 51027474af
commit 693ff3bbb3
20 changed files with 948 additions and 155 deletions

View File

@@ -59,6 +59,8 @@ struct EpilogueCShuffleBase
1,
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
__device__ static constexpr bool IsLDSNeeded() { return true; }
// *Caution Here repeat is shuffle repeat
__device__ static constexpr auto
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()

View File

@@ -0,0 +1,145 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
// Epilogue that writes the GEMM accumulator registers directly to the E tensor
// in global memory, skipping the LDS-based C-shuffle stage entirely
// (IsLDSNeeded() returns false, so no epilogue LDS needs to be allocated).
//
// The Ds pointer, LDS workspace and Ds descriptor arguments of Run() are
// accepted only so the signature stays interchangeable with the C-shuffle
// epilogue; they are ignored here. NOTE(review): this presumably means the
// direct-store path is only selected when no D tensors are fused — confirm
// against the caller's UseDirectStore gate.
//
// Template parameters:
//   DsDataType              - tuple of auxiliary D tensor data types (unused here)
//   EDataType               - element type of the output tensor E
//   AccDataType             - accumulator (VGPR) element type
//   MRepeat / NRepeat       - per-thread repeat counts of the blockwise GEMM
//   CDEElementwiseOperation - elementwise op applied while storing C -> E
//   BlockwiseGemmPipe       - pipeline providing the C thread/block layout descriptors
template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          index_t MRepeat,
          index_t NRepeat,
          typename CDEElementwiseOperation,
          typename BlockwiseGemmPipe>
struct EpilogueDirectStore
{
    // Compile-time dimension indices used when querying descriptor lengths.
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    // No LDS staging: results go from VGPRs straight to global memory.
    __device__ static constexpr bool IsLDSNeeded() { return false; }

    // Applies cde_element_op to the accumulators in c_thread_buf and stores them
    // to p_e_grid for the block tile selected by (block_m_id, block_n_id).
    // The unnamed DsGridPointer / void* / Ds-descriptor parameters exist only
    // for signature compatibility with the C-shuffle epilogue.
    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ static void Run(CThreadBuf& c_thread_buf,
                               DsGridPointer,
                               EDataType* p_e_grid,
                               void*,
                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
    {
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping in single thread.
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // C mapping in single block.
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                BlockwiseGemmPipe::
                    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // Extract the per-block layout extents from the block descriptor.
        constexpr auto MWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I1);
        constexpr auto MSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I2);
        constexpr auto NWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I4);
        constexpr auto NThreadPerSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I5);
        constexpr auto MAccVgprs =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I6);

        // Origin of this thread's C tile within the block.
        const auto c_thread_mtx_on_block =
            BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);

        // Split the flat per-block M coordinate back into
        // (MRepeat, MWave, MSubGroup, MAccVgprs) components.
        const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
                make_multi_index(c_thread_mtx_on_block[I0]));

        // Split the flat per-block N coordinate back into
        // (NRepeat, NWave, NThreadPerSubGroup) components.
        const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                make_multi_index(c_thread_mtx_on_block[I1]));

        // E grid descriptor: freeze the block indices and unmerge MPerBlock /
        // NPerBlock into the per-thread layout. In the N unmerge the lower-id
        // mapping Sequence<4, 5, 3> places NRepeat as the fastest-varying
        // component of N — i.e. each thread owns NRepeat contiguous E elements,
        // which is what enables the vector store below. NOTE(review): this
        // interleaved N ordering must match the B-side interleaved wave-tile
        // transfer — confirm against ABTransferWaveTilesInterleave.
        const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
            transform_tensor_descriptor(
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(make_freeze_transform(block_m_id),
                           make_unmerge_transform(make_tuple(Number<MRepeat>{},
                                                             Number<MWave>{},
                                                             Number<MSubGroup>{},
                                                             Number<MAccVgprs>{})),
                           make_freeze_transform(block_n_id),
                           make_unmerge_transform(make_tuple(
                               Number<NWave>{}, Number<NThreadPerSubGroup>{}, Number<NRepeat>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{}));

        // Threadwise VGPR -> global copy. Vector dimension 3 is the NRepeat
        // dimension with VectorSize == NRepeat, so the store is vectorized over
        // the NRepeat contiguous output elements owned by this thread.
        auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType,
            EDataType,
            decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
            decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
            CDEElementwiseOperation,
            Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
            Sequence<0, 1, 2, 3, 4, 5, 6>,
            3,
            NRepeat, // VectorSize
            EGlobalMemoryDataOperation,
            1,
            false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                   make_multi_index(m_thread_data_on_grid_idx[I0],
                                    m_thread_data_on_grid_idx[I1],
                                    m_thread_data_on_grid_idx[I2],
                                    n_thread_data_on_grid_idx[I0],
                                    n_thread_data_on_grid_idx[I1],
                                    n_thread_data_on_grid_idx[I2],
                                    m_thread_data_on_grid_idx[I3]),
                   cde_element_op};

        c_thread_copy.Run(
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
            make_tuple(I0, I0, I0, I0, I0, I0, I0),
            c_thread_buf,
            c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
            e_grid_buf);
    }
};
} // namespace ck

View File

@@ -77,26 +77,79 @@ struct ABTransferWaveTiles
static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
// Right-pads the MN and/or K dimension of a 2D (MN, K) grid descriptor up to
// MNPad / KPad. A dimension whose Pad* flag is false is passed through
// unchanged; when no padding is requested the descriptor is returned as-is.
// The two trailing unnamed index_t parameters exist only for signature
// compatibility with related descriptor factories.
template <bool PadMN, bool PadK, typename GridDescriptorBase>
__host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc,
                                                  index_t sizeMN,
                                                  index_t MNPad,
                                                  index_t sizeK,
                                                  index_t KPad,
                                                  index_t,
                                                  index_t)
{
    if constexpr(!PadMN && !PadK)
    {
        // Neither dimension padded: nothing to transform.
        return base_desc;
    }
    else if constexpr(PadMN && PadK)
    {
        // Both MN and K padded.
        return transform_tensor_descriptor(
            base_desc,
            make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
                       make_right_pad_transform(sizeK, KPad - sizeK)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else if constexpr(PadMN)
    {
        // Only MN padded; K passes through.
        return transform_tensor_descriptor(
            base_desc,
            make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
                       make_pass_through_transform(sizeK)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Only K padded; MN passes through.
        return transform_tensor_descriptor(
            base_desc,
            make_tuple(make_pass_through_transform(sizeMN),
                       make_right_pad_transform(sizeK, KPad - sizeK)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
template <bool PadMN, bool PadK, typename GridDescriptorBase>
__host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
index_t sizeMN,
index_t,
index_t MNPad,
index_t sizeK,
index_t,
index_t KPad,
index_t,
index_t)
{
// Notes: padding is currently not supported
static_assert(!PadMN && !PadK, "padding is currently not supported");
// Notes: padding is currently not supported with transpose
static_assert(!((PadMN || PadK) && ABDoTranspose),
"padding is currently not supported with transpose");
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;
const auto base_desc_padded =
PadGridDescriptor<PadMN, PadK>(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);
// Divide the base descriptor MN_K into tiles
const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
base_desc,
base_desc_padded,
make_tuple(
make_unmerge_transform(make_tuple(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number<KPack>{}),
Number<KPack>{}))),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
make_unmerge_transform(make_tuple(
math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
@@ -112,9 +165,9 @@ struct ABTransferWaveTiles
transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles,
make_tuple(make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(
math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNPerWmma>{}),
make_unmerge_transform(
make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
@@ -127,8 +180,8 @@ struct ABTransferWaveTiles
ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNPerWmma>{}),
make_pass_through_transform(Number<MNKRow>{}),
make_freeze_transform(I0)),
@@ -143,9 +196,9 @@ struct ABTransferWaveTiles
transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles,
make_tuple(make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(
math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_unmerge_transform(
make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
make_pass_through_transform(Number<KPack>{})),
@@ -157,8 +210,8 @@ struct ABTransferWaveTiles
ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNKRow>{}),
make_freeze_transform(I0),
make_pass_through_transform(Number<KPack>{})),

View File

@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/utility/math.hpp"
namespace ck {
// Wave-tile A/B transfer with an interleaved global layout, designed so the
// direct-store epilogue can write results without a C-shuffle pass.
// Derives from ABTransferWaveTiles and overrides the grid/block descriptors
// and the block transfer to use a grid-side wave decomposition
// (MNWaves_Grid == MNWaves_Gemm) that follows the GEMM's output wave layout.
// NOTE(review): transpose is not supported yet (static_assert below).
template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t KPack,
          index_t ABK1Value,
          index_t WaveSize,
          index_t MNWaves_Gemm>
struct ABTransferWaveTilesInterleave : ABTransferWaveTiles<ABLayout,
                                                           ABMajorLayout,
                                                           LDSTypeAB,
                                                           BlockSize,
                                                           MNPerBlock,
                                                           KPerBlock,
                                                           MNPerWmma,
                                                           KPack,
                                                           ABK1Value,
                                                           WaveSize>
{
    using Base = ABTransferWaveTiles<ABLayout,
                                     ABMajorLayout,
                                     LDSTypeAB,
                                     BlockSize,
                                     MNPerBlock,
                                     KPerBlock,
                                     MNPerWmma,
                                     KPack,
                                     ABK1Value,
                                     WaveSize>;

    // Bring base-class constants and helpers into scope.
    using Base::ABDoTranspose;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::MNKRow;

    using Base::GetBlockLaneIdx;
    // NOTE(review): GetBlockStep is redeclared below with a different step, so
    // this using-declaration is shadowed by the derived definition.
    using Base::GetBlockStep;
    using Base::GetGridLaneIdx;
    using Base::GetWaveIdx;
    // NOTE(review): PadGridDescriptor is nevertheless invoked below via
    // Base::template — this using-declaration appears redundant.
    using Base::PadGridDescriptor;
    using typename Base::ThisThreadBlock;

    static constexpr auto I4 = Number<4>{};

    static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet");

    // Block-side (LDS) repeat/wave decomposition inherited from the base.
    using Base::KRepeat_;
    using Base::KWaves_;
    using Base::MNRepeat_;

    // Grid-side decomposition: MN waves follow the GEMM output wave layout;
    // the remaining waves of the block cover K.
    static constexpr index_t MNWaves_Grid  = MNWaves_Gemm;
    static constexpr index_t KWaves_Grid   = (BlockSize / WaveSize) / MNWaves_Gemm;
    static constexpr index_t KRepeat_Grid  = KPerBlock / (KWaves_Grid * KPack);
    static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma);

    // Builds the global-memory descriptor used by the interleaved wave-tile
    // transfer: pads (MN, K) if requested, then decomposes into
    // MNTiles - KTiles - MNRepeat - MNPerWmma - MNKRow with the vector
    // dimension frozen to the first element of each loading chunk.
    // The two trailing unnamed index_t parameters keep the signature
    // compatible with the other descriptor factories.
    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
                                                       index_t sizeMN,
                                                       index_t MNPad,
                                                       index_t sizeK,
                                                       index_t KPad,
                                                       index_t,
                                                       index_t)
    {
        // Apply optional right-padding first; subsequent tiling uses the
        // padded extents.
        const auto base_desc_padded = Base::template PadGridDescriptor<PadMN, PadK>(
            base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);

        const index_t MN_grid = !PadMN ? sizeMN : MNPad;
        const index_t K_grid  = !PadK ? sizeK : KPad;

        // Divide the base descriptor MN_K into tiles. Note the MN tile width
        // here is MNPerWmma * MNRepeat_Grid (an interleaved macro-tile), unlike
        // the base class which tiles by MNPerWmma alone.
        const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
            base_desc_padded,
            make_tuple(make_unmerge_transform(make_tuple(
                           math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{}),
                           Number<MNPerWmma * MNRepeat_Grid>{})),
                       make_unmerge_transform(make_tuple(
                           math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        // The distinction is needed to get the same global indices for both layouts.
        // Divide each tile into two 16x8 subtiles:
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        // MNKRow = 0-1
        // LaneLocal = 0-15
        // VectorSize must be 8
        //
        // NOTE(review): only the !ABDoTranspose branch exists; the transpose
        // case cannot be reached because of the static_assert above, so the
        // function never falls through without a return.
        if constexpr(!ABDoTranspose)
        {
            // Split the interleaved macro-tile back into MNPerWmma x MNRepeat;
            // the lower-id mapping Sequence<3, 2> makes MNRepeat the outer and
            // MNPerWmma the inner component.
            const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_unmerge_transform(
                        make_tuple(Number<MNPerWmma>{}, Number<MNRepeat_Grid>{})),
                    make_pass_through_transform(Number<KPack>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{}));

            // Split KPack into (MNKRow, KPack / MNKRow) lane groups.
            const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
                transform_tensor_descriptor(
                    ab_grid_desc_mntiles_ktiles_mnrepeat,
                    make_tuple(make_pass_through_transform(math::integer_divide_ceil(
                                   MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                               make_pass_through_transform(
                                   math::integer_divide_ceil(K_grid, Number<KPack>{})),
                               make_pass_through_transform(Number<MNRepeat_Grid>{}),
                               make_pass_through_transform(Number<MNPerWmma>{}),
                               make_unmerge_transform(
                                   make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
                    make_tuple(
                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                    make_tuple(Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4, 5>{}));

            // Freeze VectorSize to first element of the loading chunk (for convenience).
            // Swap MNPerWmma and MNKRow for consistency with transpose descriptor.
            return transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_pass_through_transform(Number<MNRepeat_Grid>{}),
                    make_pass_through_transform(Number<MNPerWmma>{}),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<4>{},
                           Sequence<3>{},
                           Sequence<5>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<>{}));
        }
    }

    // Builds the LDS-side block descriptor matching the grid layout above.
    __device__ static constexpr auto GetBlockDescriptor()
    {
        // LDS memory layouts:
        // lanes within tiles stored contiguously in chunks of 8 elements
        // tiles are then stored first in K dimension
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        //
        // NOTE(review): despite the name, this local is the LDS block
        // descriptor, not a grid descriptor.
        const auto a_grid_desc_mraw_kraw = [&]() {
            return make_naive_tensor_descriptor(
                make_tuple(Number<MNWaves_Grid>{},
                           Number<KRepeat_Grid * KWaves_Grid>{},
                           Number<MNRepeat_Grid>{},
                           Number<MNKRow>{},
                           Number<MNPerWmma>{},
                           Number<ABK1Value>{}),
                make_tuple(Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid>{},
                           Number<KPack * MNPerWmma>{},
                           Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid * MNWaves_Grid>{},
                           Number<ABK1Value * MNPerWmma>{},
                           Number<ABK1Value>{},
                           I1));
        }();

        // Freeze VectorSize to first element of the chunk (for convenience).
        return transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            make_tuple(make_pass_through_transform(Number<MNWaves_Grid>{}),
                       make_pass_through_transform(Number<KRepeat_Grid * KWaves_Grid>{}),
                       make_pass_through_transform(Number<MNRepeat_Grid>{}),
                       make_pass_through_transform(Number<MNKRow>{}),
                       make_pass_through_transform(Number<MNPerWmma>{}),
                       make_freeze_transform(I0)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<>{}));
    }

    // Creates the thread-group transfer object performing the global -> LDS
    // copy. block_mn_id selects the block tile along MN; the trailing unnamed
    // index_t keeps the signature compatible with the thread-tile variant.
    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id,
                                            const index_t)
    {
        // Note: GlobalBufferNum is currently not used but it will be needed
        // once we add other pipelines. It is currently needed only for
        // consistency with the thread tiles approach
        static_assert(GlobalBufferNum == 1, "single global buffer is only supported");

        constexpr index_t NumABTensor = ABsDataType::Size();
        static_assert(NumABTensor == 1, "multiAB currently not supported");
        using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;

        // Wave and lane coordinates of this thread, in grid and block terms.
        const auto wave_idx = GetWaveIdx();
        index_t wave_idK    = wave_idx[I1];
        index_t wave_idMN   = wave_idx[I0];

        const auto grid_lane_id      = Base::template GetGridLaneIdx<ABDataType>();
        index_t lane_group_grid      = grid_lane_id[I0];
        index_t lane_local_id_grid   = grid_lane_id[I1];

        const auto block_lane_id     = GetBlockLaneIdx();
        index_t lane_group_block     = block_lane_id[I0];
        index_t lane_local_id_block  = block_lane_id[I1];

        // How many grid-side MN repeats each block-side repeat covers; used to
        // fold wave_idMN into (wave, repeat) coordinates below.
        constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_;

        return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
                                         BlockDescriptor,
                                         ABDataType,
                                         ABDataType,
                                         ABElementwiseOperation,
                                         Sequence<I1, KRepeat_, MNRepeat_, I1, I1>,
                                         Sequence<I1, KWaves_, I1, I1, I1>,
                                         Sequence<I0, I1, I2, I3, I4>,
                                         ABK1Value,
                                         ABDoTranspose>(
            grid_descriptor[I0],
            block_descriptor,
            // Grid-side origin: (MN wave, K tile, MN repeat, lane group, lane).
            make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
                             wave_idK * KRepeat_Grid,
                             (wave_idMN % MNRepeatRatio) * MNRepeat_,
                             lane_group_grid,
                             lane_local_id_grid),
            // Block-side (LDS) origin in the matching coordinate system.
            make_multi_index(wave_idMN / MNRepeatRatio,
                             wave_idK * KRepeat_,
                             (wave_idMN % MNRepeatRatio) * MNRepeat_,
                             lane_group_block,
                             lane_local_id_block),
            ab_element_op);
    }

    // Per-iteration advance of the source slice window along K; overrides the
    // base-class step to match the interleaved grid descriptor's K dimension.
    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow)
        return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0);
    }
};
} // namespace ck

View File

@@ -177,7 +177,8 @@ template <typename ALayout,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
bool ForceThreadTileTransfer = false>
bool ForceThreadTileTransfer = false,
bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
PermuteA,
PermuteB,
IsBPreShuffled,
ForceThreadTileTransfer>
ForceThreadTileTransfer,
IsFusedKernel>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
PermuteA,
PermuteB,
IsBPreShuffled,
ForceThreadTileTransfer>;
ForceThreadTileTransfer,
IsFusedKernel>;
using Base::I0;
using Base::I1;

View File

@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
@@ -24,6 +25,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
@@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);
@@ -167,7 +175,8 @@ template <typename ALayout,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
bool ForceThreadTileTransfer = false> // only needed for convolution (limitation)
bool ForceThreadTileTransfer = false, // only needed for convolution (limitation)
bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3_base
{
@@ -182,6 +191,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
using LDSTypeA =
typename std::conditional<(NumATensor > 1),
@@ -232,30 +242,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return 1;
}();
static constexpr index_t WaveSize =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default
// - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation)
// AK1Value == 8 is not really a limitation but a requirement for the method so
// it will stay
// - GemmSpecialization Default with transpose
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable =
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
static constexpr bool IsBWaveTransferApplicable =
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
static constexpr bool IsWaveTileInterleavedFitting =
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
// We need to investigate if it makes sense to remove cshuffle for smaller types
// Currently we use direct store for NRepeat equal to 4 or 8. For 16 bit type we use at
// least buffer store 64 bit for 16 contiguous threads -> 128 bytes in total (full cache line)
static constexpr bool UseDirectStore = is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 &&
NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) &&
!IsFusedKernel && IsWaveTileInterleavedFitting;
#else
static constexpr bool IsAWaveTransferApplicable = false;
static constexpr bool IsBWaveTransferApplicable = false;
static constexpr bool UseDirectStore = false;
#endif
static constexpr index_t WaveSize =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
static constexpr bool UseBlockPaddingA =
ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
using ATransfer = typename std::conditional<
@@ -293,7 +317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr bool UseBlockPaddingB =
BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
using BTransfer = typename std::conditional<
IsBPreShuffled,
ABTransferThreadTilesPreShuffle<BLayout,
@@ -309,16 +332,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BThreadTransferSrcResetCoordinateAfterRun>,
typename std::conditional<
IsBWaveTransferApplicable,
ABTransferWaveTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize>,
typename std::conditional<
UseDirectStore,
ABTransferWaveTilesInterleave<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize,
NPerBlock / NPerWmma / NRepeat>,
ABTransferWaveTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize>>::type,
ABTransferThreadTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
@@ -490,6 +526,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumATensor>{});
}
template <typename GridDescBase>
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc)
{
const auto M = base_desc.GetLength(I0);
const auto K = base_desc.GetLength(I1);
const auto AK0 = K / AK1Value;
constexpr bool padM = false;
constexpr bool padK = false;
return ATransfer::template MakeGridDescriptor<padM, padK>(base_desc, M, M, K, K, 0, AK0);
}
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
const index_t KPad,
@@ -516,6 +565,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumBTensor>{});
}
template <typename GridDescBase>
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc)
{
const auto N = base_desc.GetLength(I0);
const auto K = base_desc.GetLength(I1);
const auto BK0 = K / BK1Value;
constexpr bool padN = false;
constexpr bool padK = false;
return BTransfer::template MakeGridDescriptor<padN, padK>(base_desc, N, N, K, K, 0, BK0);
}
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
{
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
@@ -594,8 +656,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
@@ -679,6 +739,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
ThisThreadBlock,
BlockwiseGemmPipe>;
using EpilogueDirectStore = EpilogueDirectStore<DsDataType,
EDataType,
AccDataType,
MRepeat,
NRepeat,
CDEElementwiseOperation,
BlockwiseGemmPipe>;
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
DsDataType,
EDataType,
@@ -1000,18 +1068,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
max_lds_align)
: 0;
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
EpilogueType::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
if constexpr(EpilogueType::IsLDSNeeded())
{
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
EpilogueType::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto c_block_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
constexpr auto c_block_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
}
else
{
return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize;
}
}
template <index_t numElements, typename Type>
@@ -1148,7 +1224,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
// Epilogue:
// - CShuffle / direct store
// - Multiple Ds
// - Fused operations
epilogue_args.template Run<EGlobalMemoryDataOperation>(
c_thread_buf,
p_ds_grid,