Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)

- Add support for direct store in the epilogue instead of CShuffle (a compile-time selection sketch follows the commit metadata below)
- Add padding support for wave transfer without transpose
- Add wave transfer with an interleaved layout to support direct store
- Enable the new functionality on GEMMs
- Add optional support for the new functionality in grouped convolution forward
- Add some fast instances for grouped convolution forward with the new functionality (proper tuning still needed)
Author: Enrico Degregori
Date: 2026-01-14 11:02:19 +01:00 (committed by GitHub)
Parent: 51027474af
Commit: 693ff3bbb3
20 changed files with 948 additions and 155 deletions
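
Every kernel entry point touched by this commit selects its epilogue with the same compile-time pattern: fall back to the CShuffle epilogue unless both `IsBWaveTransferApplicable` and `UseDirectStore` hold. A minimal standalone sketch of that selection (the struct and flag values below are hypothetical stand-ins for the real gridwise GEMM; only the `std::conditional` pattern is taken from the diff):

```cpp
#include <type_traits>

// Hypothetical stand-ins for the two CK epilogues the kernels choose between.
struct EpilogueCShuffle    {};  // stages C through LDS, then vector-stores to global
struct EpilogueDirectStore {};  // stores accumulator VGPRs straight to global memory

// Hypothetical gridwise GEMM exposing the same compile-time flags used in the diff.
struct MyGridwiseGemm
{
    static constexpr bool IsBWaveTransferApplicable = true;  // B uses the wave-tile transfer
    static constexpr bool UseDirectStore            = true;  // direct-store conditions are met
};

// The selection pattern repeated in every kernel entry point of this commit.
using EpilogueType =
    std::conditional_t<MyGridwiseGemm::IsBWaveTransferApplicable && MyGridwiseGemm::UseDirectStore,
                       EpilogueDirectStore,
                       EpilogueCShuffle>;

static_assert(std::is_same_v<EpilogueType, EpilogueDirectStore>,
              "with both flags set, the direct-store epilogue is picked");

int main() { return 0; }
```

The selected epilogue type then also drives the shared-memory size and the `epilogue_args` object passed to `GridwiseGemm::Run`, as the hunks below show.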

View File

@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
});
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid_shift,

View File

@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<c_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
// The normal approach to batching would be to increase the grid size by just stretching out
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
// functions not directly using the Z dimension for other calculations. As it turns out, k
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
});
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid_shift,

View File

@@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
PermuteB,
false, // IsBPreShuffled
false, // ForceThreadTileTransfer
true>; // IsFusedKernel
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,

View File

@@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
PermuteB,
false,
false,
true>;
// Welford 2nd part kernel
template <typename DoPads, index_t MPerTile, index_t NPerTile>

View File

@@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
PermuteB,
false,
false,
true>;
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,
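
The three fused device ops above append `false, false, true` for the new trailing template parameters (`IsBPreShuffled`, `ForceThreadTileTransfer`, `IsFusedKernel`). A hedged sketch of the idea, with a made-up `GridwiseGemmSketch` in place of the real template: appending defaulted booleans keeps existing instantiations source-compatible, and passing `IsFusedKernel = true` keeps the fused kernels on the CShuffle epilogue (mirroring, in simplified form, the `UseDirectStore` condition added later in the gridwise base):

```cpp
// New bool parameters are appended with defaults so existing instantiations keep compiling.
template <typename EDataType,
          int  NRepeat,
          bool ForceThreadTileTransfer = false,
          bool IsFusedKernel           = false>  // new in this commit
struct GridwiseGemmSketch
{
    // Simplified stand-in for the real UseDirectStore condition
    // (layout, NumDTensor and tile-fit checks omitted).
    static constexpr bool UseDirectStore =
        sizeof(EDataType) == 2 && (NRepeat == 4 || NRepeat == 8) && !IsFusedKernel;
};

// Plain GEMM: the defaults apply and direct store stays available
// (unsigned short stands in for a 16-bit element type).
static_assert(GridwiseGemmSketch<unsigned short, 8>::UseDirectStore, "");

// Fused device op (bias+add+reduce, layernorm, reduce): passes IsFusedKernel = true,
// which keeps it on the CShuffle epilogue.
static_assert(!GridwiseGemmSketch<unsigned short, 8, false, true>::UseDirectStore, "");

int main() { return 0; }
```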

View File

@@ -48,8 +48,8 @@ namespace {
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*/
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffset, // For Batch (group) and N
@@ -63,8 +63,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_fwd_wmma_cshuffle_v3(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const AGridDesc_M_K a_grid_desc_m_k,
const BGridDesc_N_K b_grid_desc_n_k,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -82,13 +82,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>()];
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
auto epilogue_args = EpilogueType{};
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffset,
@@ -110,8 +123,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
#else
ignore = karg;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = a_grid_desc_m_k;
ignore = b_grid_desc_n_k;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
@@ -187,6 +200,7 @@ template <index_t NDimSpatial,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool UseThreadTileTransfer = true,
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
@@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
NPerBlock / ClusterLengthNPerBlock>{};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
@@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
@@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
@@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
// Force MN padding on the output tensor. This allows using the GEMM Default or K-only padding
// specializations and removes some instructions from the hot loop (same approach as gemm universal).
if constexpr(CTranspose)
{
constexpr auto matrix_padder_trans =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
constexpr auto matrix_padder_MN_padding_trans =
MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
NPerBlock, MPerBlock, KPerBlock};
return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
else
{
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
constexpr auto matrix_padder_MN_padding =
MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
}
@@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
BlkGemmPipelineVer,
AComputeDataType,
BComputeDataType,
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
UseThreadTileTransfer>; // ForceThreadTileTransfer
// TODO: Previously available template param DoElementwiseBeforeCShuffle!
@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
false, // PermuteB
false, // PermuteA
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer)
using GridwiseGemmCTranspose =
std::conditional_t<CTranspose, GridwiseGemmSwappedParams, GridwiseGemm>;
@@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
I1>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
dummy_conv_to_gemm_transformer))>;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
// Argument
struct Argument : public BaseArgument
@@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
a_grid_desc_m_k_{MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
@@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
}
{
const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1);
const index_t GemmM = a_grid_desc_m_k_.GetLength(I0);
const index_t GemmN = b_grid_desc_n_k_.GetLength(I0);
const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN)
: GridwiseGemmCTranspose::CalculateMBlock(GemmM);
const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM)
@@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
is_same_v<ALayout, tensor_layout::convolution::NDHWGC>)
{
size_as_buffers[i] =
(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
(a_grid_desc_m_k_.GetElementSpaceSize() +
(num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) *
sizeof(ADataType_single) / GridwiseGemm::APackedSize;
}
@@ -891,13 +888,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
{
if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1)
{
size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() +
(eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) *
sizeof(ADataType_single) / GridwiseGemm::APackedSize;
}
else
{
size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() *
size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() *
eff_num_group * sizeof(ADataType_single) /
GridwiseGemm::APackedSize;
}
@@ -914,7 +911,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_single = remove_cvref_t<tuple_element_t<i.value, GemmBsDataType>>;
size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group *
size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group *
sizeof(BDataType_single) / GridwiseGemm::BPackedSize;
});
@@ -961,8 +958,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -998,8 +995,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
@@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
dim3(BlockSize),
0,
gemm_arg_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_n_k_,
arg.a_grid_desc_m_k_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
@@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
dim3(BlockSize),
0,
gemm_arg,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_n_k_,
arg.a_grid_desc_m_k_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
@@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
@@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_groups_,
@@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
{
const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
GridwiseGemmCTranspose,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_N_K,
DeviceOp::AGridDesc_M_K,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffset,
@@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
{
const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
GridwiseGemm,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_M_K,
DeviceOp::BGridDesc_N_K,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffset,
@@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
}
// check Gridwise GEMM
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
if constexpr(CTranspose)
{
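
With this change the device op stores and passes the plain `M x K` / `N x K` descriptors, and the gridwise GEMM unmerges K into `(K0, K1)` on device (`MakeAGridDescriptor_AK0_M_AK1` / `MakeBGridDescriptor_BK0_N_BK1` in the base class). A small host-side sketch of the index arithmetic behind that unmerge, assuming a row-major M x K tensor and AK1 = 8; the helper names are illustrative, not CK's `transform_tensor_descriptor`:

```cpp
#include <cassert>
#include <cstdint>

// Flat offset of element (m, k) in a row-major M x K tensor.
constexpr std::int64_t offset_m_k(std::int64_t m, std::int64_t k, std::int64_t K)
{
    return m * K + k;
}

// Offset of the same element when viewed as [AK0, M, AK1] with K unmerged into (AK0, AK1),
// which is what the kernel-side MakeAGridDescriptor_AK0_M_AK1 effectively does.
constexpr std::int64_t offset_ak0_m_ak1(std::int64_t ak0, std::int64_t m, std::int64_t ak1,
                                        std::int64_t K, std::int64_t AK1)
{
    return offset_m_k(m, ak0 * AK1 + ak1, K);
}

int main()
{
    constexpr std::int64_t M = 4, K = 32, AK1 = 8, AK0 = K / AK1;
    // Every (ak0, m, ak1) triple addresses the same element as (m, ak0 * AK1 + ak1).
    for(std::int64_t ak0 = 0; ak0 < AK0; ++ak0)
        for(std::int64_t m = 0; m < M; ++m)
            for(std::int64_t ak1 = 0; ak1 < AK1; ++ak1)
                assert(offset_ak0_m_ak1(ak0, m, ak1, K, AK1) ==
                       offset_m_k(m, ak0 * AK1 + ak1, K));
    return 0;
}
```

This is also why GemmM, GemmN and GemmK can now be read directly as `GetLength(I0)` / `GetLength(I1)` of the 2-D descriptors in the hunks above.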

View File

@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CDEElementwiseOperation cde_element_op)
{
#if(defined(__gfx11__) || defined(__gfx12__))
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ uint8_t p_shared[LDS_size];
const auto gemm_desc_ptr =
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
gemm_desc_ptr[group_id].StrideE,
1);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
constexpr TailNumber TailNum = TailNumber::Full;
if(has_main_k_block_loop)

View File

@@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const index_t group_count)
{
#if(defined(__gfx11__) || defined(__gfx12__))
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
const index_t block_id = get_block_1d_id();
@@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
typename GridwiseGemm::EpilogueCShuffle,
EpilogueType,
1, // Block2CTileMap MBlock index
2 // Block2CTileMap NBlock index
>(static_cast<void*>(p_shared),

View File

@@ -59,6 +59,8 @@ struct EpilogueCShuffleBase
1,
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
__device__ static constexpr bool IsLDSNeeded() { return true; }
// *Caution Here repeat is shuffle repeat
__device__ static constexpr auto
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
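
`IsLDSNeeded()` is the hook that lets `GetSharedMemoryNumberOfByte` drop the C-shuffle scratch when the direct-store epilogue is used (the corresponding `if constexpr` appears in the gridwise base class further down). A standalone sketch of that sizing logic with illustrative byte counts in place of the real descriptor sizes:

```cpp
#include <algorithm>
#include <cstddef>

// Illustrative epilogues: only the CShuffle variant needs LDS scratch for C.
struct EpilogueCShuffleSketch
{
    static constexpr bool        IsLDSNeeded()      { return true; }
    static constexpr std::size_t CShuffleLdsBytes() { return 16384; }  // illustrative size
};
struct EpilogueDirectStoreSketch
{
    static constexpr bool IsLDSNeeded() { return false; }
};

// Simplified GetSharedMemoryNumberOfByte<EpilogueType>(): the A/B tile space is always
// needed; the C-shuffle space can reuse it (hence the max) and disappears entirely
// for the direct-store epilogue.
template <typename EpilogueType>
constexpr std::size_t shared_memory_bytes(std::size_t a_tile_bytes, std::size_t b_tile_bytes)
{
    if constexpr(EpilogueType::IsLDSNeeded())
    {
        return std::max(a_tile_bytes + b_tile_bytes, EpilogueType::CShuffleLdsBytes());
    }
    else
    {
        return a_tile_bytes + b_tile_bytes;
    }
}

static_assert(shared_memory_bytes<EpilogueCShuffleSketch>(4096, 4096) == 16384, "");
static_assert(shared_memory_bytes<EpilogueDirectStoreSketch>(4096, 4096) == 8192, "");

int main() { return 0; }
```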

View File

@@ -0,0 +1,145 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename DsDataType,
typename EDataType,
typename AccDataType,
index_t MRepeat,
index_t NRepeat,
typename CDEElementwiseOperation,
typename BlockwiseGemmPipe>
struct EpilogueDirectStore
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
__device__ static constexpr bool IsLDSNeeded() { return false; }
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename CThreadBuf,
typename DsGridPointer,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ static void Run(CThreadBuf& c_thread_buf,
DsGridPointer,
EDataType* p_e_grid,
void*,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
CDEElementwiseOperation& cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id)
{
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
BlockwiseGemmPipe::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// C mapping in single block
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
BlockwiseGemmPipe::
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I1);
constexpr auto MSubGroup =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I2);
constexpr auto NWave =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I4);
constexpr auto NThreadPerSubGroup =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I5);
constexpr auto MAccVgprs =
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.GetLength(I6);
// origin
const auto c_thread_mtx_on_block =
BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(c_thread_mtx_on_block[I0]));
const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(c_thread_mtx_on_block[I1]));
// E grid descriptor
const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
transform_tensor_descriptor(
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_freeze_transform(block_m_id),
make_unmerge_transform(make_tuple(Number<MRepeat>{},
Number<MWave>{},
Number<MSubGroup>{},
Number<MAccVgprs>{})),
make_freeze_transform(block_n_id),
make_unmerge_transform(make_tuple(
Number<NWave>{}, Number<NThreadPerSubGroup>{}, Number<NRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{}));
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EDataType,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
CDEElementwiseOperation,
Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
3,
NRepeat, // VectorSize
EGlobalMemoryDataOperation,
1,
false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
make_multi_index(m_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
n_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3]),
cde_element_op};
c_thread_copy.Run(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
e_grid_buf);
}
};
} // namespace ck
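
Conceptually, `EpilogueDirectStore` skips the LDS round trip: each thread applies the elementwise op to its accumulators and writes them to E directly, and the interleaved B layout makes each thread's NRepeat output elements contiguous along N, so the write can be one NRepeat-wide vector store. A much-simplified host-side sketch of that per-thread store; the real code uses `ThreadwiseTensorSliceTransfer_v1r3` and CK descriptors, and the tile shape below is hypothetical:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using half_t = std::uint16_t;  // stand-in for a 16-bit output element

// One thread's direct store: the thread owns `rows` output rows and NRepeat
// consecutive columns starting at (m0, n0). Because those NRepeat elements are
// contiguous in N, each row is written with one NRepeat-wide store (8 or 16 bytes
// for a 16-bit type) instead of going through LDS and a separate CShuffle pass.
template <int NRepeat, typename ElementwiseOp>
void direct_store_thread_tile(const float* acc,  // rows * NRepeat accumulators held in registers
                              half_t* e_grid,    // output tensor, row-major M x N
                              int N, int m0, int n0, int rows, ElementwiseOp op)
{
    for(int r = 0; r < rows; ++r)
    {
        half_t out[NRepeat];
        for(int j = 0; j < NRepeat; ++j)
            out[j] = static_cast<half_t>(op(acc[r * NRepeat + j]));  // elementwise op, then narrow
        // Single contiguous store of NRepeat elements (a vectorized buffer store in the kernel).
        std::memcpy(&e_grid[std::int64_t(m0 + r) * N + n0], out, sizeof(out));
    }
}

int main()
{
    constexpr int NRepeat = 8;
    const int M = 16, N = 32;
    std::vector<float>  acc(2 * NRepeat, 1.0f);
    std::vector<half_t> e(std::size_t(M) * N, 0);
    direct_store_thread_tile<NRepeat>(acc.data(), e.data(), N, /*m0*/ 0, /*n0*/ 8, /*rows*/ 2,
                                      [](float x) { return x; });
    return 0;
}
```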

View File

@@ -77,26 +77,79 @@ struct ABTransferWaveTiles
static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
template <bool PadMN, bool PadK, typename GridDescriptorBase>
__host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc,
index_t sizeMN,
index_t MNPad,
index_t sizeK,
index_t KPad,
index_t,
index_t)
{
if constexpr(PadMN && PadK)
{
// pad both MN and K
return transform_tensor_descriptor(
base_desc,
make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
make_right_pad_transform(sizeK, KPad - sizeK)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(PadMN && !PadK)
{
// pad MN, but not K
return transform_tensor_descriptor(
base_desc,
make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
make_pass_through_transform(sizeK)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(!PadMN && PadK)
{
// pad K, but not MN
return transform_tensor_descriptor(
base_desc,
make_tuple(make_pass_through_transform(sizeMN),
make_right_pad_transform(sizeK, KPad - sizeK)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad MN or K
return base_desc;
}
}
template <bool PadMN, bool PadK, typename GridDescriptorBase>
__host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
index_t sizeMN,
index_t,
index_t MNPad,
index_t sizeK,
index_t,
index_t KPad,
index_t,
index_t)
{
// Notes: padding is currently not supported
static_assert(!PadMN && !PadK, "padding is currently not supported");
// Notes: padding is currently not supported with transpose
static_assert(!((PadMN || PadK) && ABDoTranspose),
"padding is currently not supported with transpose");
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;
const auto base_desc_padded =
PadGridDescriptor<PadMN, PadK>(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);
// Divide the base descriptor MN_K into tiles
const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
base_desc,
base_desc_padded,
make_tuple(
make_unmerge_transform(make_tuple(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number<KPack>{}),
Number<KPack>{}))),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
make_unmerge_transform(make_tuple(
math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
@@ -112,9 +165,9 @@ struct ABTransferWaveTiles
transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles,
make_tuple(make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(
math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNPerWmma>{}),
make_unmerge_transform(
make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
@@ -127,8 +180,8 @@ struct ABTransferWaveTiles
ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNPerWmma>{}),
make_pass_through_transform(Number<MNKRow>{}),
make_freeze_transform(I0)),
@@ -143,9 +196,9 @@ struct ABTransferWaveTiles
transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles,
make_tuple(make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(
math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_unmerge_transform(
make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
make_pass_through_transform(Number<KPack>{})),
@@ -157,8 +210,8 @@ struct ABTransferWaveTiles
ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNKRow>{}),
make_freeze_transform(I0),
make_pass_through_transform(Number<KPack>{})),
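
The padding path added above right-pads MN and/or K on the base descriptor before the wave-tile split, so the tiling (`integer_divide_ceil`) runs on the padded extents `MN_grid` / `K_grid`. A hedged sketch that only models the resulting extents and tile counts (the actual padding is a `make_right_pad_transform` on the descriptor; the helpers below are illustrative):

```cpp
#include <cstdint>

struct Extents
{
    std::int64_t mn;
    std::int64_t k;
};

// Compile-time dispatch over the (PadMN, PadK) combinations, mirroring PadGridDescriptor,
// but reduced to the padded extents it produces.
template <bool PadMN, bool PadK>
constexpr Extents padded_extents(std::int64_t sizeMN, std::int64_t MNPad,
                                 std::int64_t sizeK,  std::int64_t KPad)
{
    return {PadMN ? MNPad : sizeMN, PadK ? KPad : sizeK};
}

constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

int main()
{
    constexpr std::int64_t MNPerWmma = 16, KPack = 8;
    // e.g. N = 1000 right-padded up to 1024; K is already divisible, so it is passed through.
    constexpr auto e = padded_extents<true, false>(/*sizeMN*/ 1000, /*MNPad*/ 1024,
                                                   /*sizeK*/ 256, /*KPad*/ 256);
    static_assert(e.mn == 1024 && e.k == 256, "");
    // The wave-tile split then works on the padded extents (MN_grid / K_grid in the diff).
    static_assert(ceil_div(e.mn, MNPerWmma) == 64, "");
    static_assert(ceil_div(e.k, KPack) == 32, "");
    return 0;
}
```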

View File

@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/utility/math.hpp"
namespace ck {
template <typename ABLayout,
typename ABMajorLayout,
typename LDSTypeAB,
index_t BlockSize,
index_t MNPerBlock,
index_t KPerBlock,
index_t MNPerWmma,
index_t KPack,
index_t ABK1Value,
index_t WaveSize,
index_t MNWaves_Gemm>
struct ABTransferWaveTilesInterleave : ABTransferWaveTiles<ABLayout,
ABMajorLayout,
LDSTypeAB,
BlockSize,
MNPerBlock,
KPerBlock,
MNPerWmma,
KPack,
ABK1Value,
WaveSize>
{
using Base = ABTransferWaveTiles<ABLayout,
ABMajorLayout,
LDSTypeAB,
BlockSize,
MNPerBlock,
KPerBlock,
MNPerWmma,
KPack,
ABK1Value,
WaveSize>;
using Base::ABDoTranspose;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::MNKRow;
using Base::GetBlockLaneIdx;
using Base::GetBlockStep;
using Base::GetGridLaneIdx;
using Base::GetWaveIdx;
using Base::PadGridDescriptor;
using typename Base::ThisThreadBlock;
static constexpr auto I4 = Number<4>{};
static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet");
using Base::KRepeat_;
using Base::KWaves_;
using Base::MNRepeat_;
static constexpr index_t MNWaves_Grid = MNWaves_Gemm;
static constexpr index_t KWaves_Grid = (BlockSize / WaveSize) / MNWaves_Gemm;
static constexpr index_t KRepeat_Grid = KPerBlock / (KWaves_Grid * KPack);
static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma);
template <bool PadMN, bool PadK, typename GridDescriptorBase>
__host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
index_t sizeMN,
index_t MNPad,
index_t sizeK,
index_t KPad,
index_t,
index_t)
{
const auto base_desc_padded = Base::template PadGridDescriptor<PadMN, PadK>(
base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;
// Divide the base descriptor MN_K into tiles
const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
base_desc_padded,
make_tuple(make_unmerge_transform(make_tuple(
math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{}),
Number<MNPerWmma * MNRepeat_Grid>{})),
make_unmerge_transform(make_tuple(
math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
// The distinction is needed to get the same global indices for both layouts
// Divide each tile into two 16x8 subtiles
// MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
// MNKRow = 0-1
// LaneLocal = 0-15
// VectorSize must be 8
if constexpr(!ABDoTranspose)
{
const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_unmerge_transform(
make_tuple(Number<MNPerWmma>{}, Number<MNRepeat_Grid>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{}));
const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles_mnrepeat,
make_tuple(make_pass_through_transform(math::integer_divide_ceil(
MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
make_pass_through_transform(
math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNRepeat_Grid>{}),
make_pass_through_transform(Number<MNPerWmma>{}),
make_unmerge_transform(
make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5>{}));
// Freeze VectorSize to first element of the loading chunk (for convenience)
// Swap MNPerWmma and MNKRow for consistency with transpose descriptor
return transform_tensor_descriptor(
ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
make_tuple(
make_pass_through_transform(
math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
make_pass_through_transform(Number<MNRepeat_Grid>{}),
make_pass_through_transform(Number<MNPerWmma>{}),
make_pass_through_transform(Number<MNKRow>{}),
make_freeze_transform(I0)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<3>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{}));
}
}
__device__ static constexpr auto GetBlockDescriptor()
{
// LDS memory layout:
// lanes within a tile are stored contiguously in chunks of 8 elements,
// tiles are then stored along the K dimension first
// MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
const auto a_grid_desc_mraw_kraw = [&]() {
return make_naive_tensor_descriptor(
make_tuple(Number<MNWaves_Grid>{},
Number<KRepeat_Grid * KWaves_Grid>{},
Number<MNRepeat_Grid>{},
Number<MNKRow>{},
Number<MNPerWmma>{},
Number<ABK1Value>{}),
make_tuple(Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid>{},
Number<KPack * MNPerWmma>{},
Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid * MNWaves_Grid>{},
Number<ABK1Value * MNPerWmma>{},
Number<ABK1Value>{},
I1));
}();
// Freeze VectorSize to first element of the chunk (for convenience)
return transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(Number<MNWaves_Grid>{}),
make_pass_through_transform(Number<KRepeat_Grid * KWaves_Grid>{}),
make_pass_through_transform(Number<MNRepeat_Grid>{}),
make_pass_through_transform(Number<MNKRow>{}),
make_pass_through_transform(Number<MNPerWmma>{}),
make_freeze_transform(I0)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{}));
}
template <typename GridDescriptor,
typename BlockDescriptor,
typename ABsDataType,
typename ABElementwiseOperation,
index_t GlobalBufferNum>
__device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
BlockDescriptor& block_descriptor,
ABElementwiseOperation& ab_element_op,
const index_t block_mn_id,
const index_t)
{
// Note: GlobalBufferNum is currently not used, but it will be needed
// once we add other pipelines. For now it is kept only for
// consistency with the thread tiles approach
static_assert(GlobalBufferNum == 1, "only a single global buffer is supported");
constexpr index_t NumABTensor = ABsDataType::Size();
static_assert(NumABTensor == 1, "multiAB is currently not supported");
using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
const auto wave_idx = GetWaveIdx();
index_t wave_idK = wave_idx[I1];
index_t wave_idMN = wave_idx[I0];
const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
index_t lane_group_grid = grid_lane_id[I0];
index_t lane_local_id_grid = grid_lane_id[I1];
const auto block_lane_id = GetBlockLaneIdx();
index_t lane_group_block = block_lane_id[I0];
index_t lane_local_id_block = block_lane_id[I1];
constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_;
return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
BlockDescriptor,
ABDataType,
ABDataType,
ABElementwiseOperation,
Sequence<I1, KRepeat_, MNRepeat_, I1, I1>,
Sequence<I1, KWaves_, I1, I1, I1>,
Sequence<I0, I1, I2, I3, I4>,
ABK1Value,
ABDoTranspose>(
grid_descriptor[I0],
block_descriptor,
make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
wave_idK * KRepeat_Grid,
(wave_idMN % MNRepeatRatio) * MNRepeat_,
lane_group_grid,
lane_local_id_grid),
make_multi_index(wave_idMN / MNRepeatRatio,
wave_idK * KRepeat_,
(wave_idMN % MNRepeatRatio) * MNRepeat_,
lane_group_block,
lane_local_id_block),
ab_element_op);
}
__device__ static constexpr auto GetBlockStep()
{
// Grid descriptor step (MoveSrcSliceWindow)
return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0);
}
};
} // namespace ck
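
The interleaved transfer keeps the GEMM's MN wave split (`MNWaves_Grid = MNWaves_Gemm`) and spreads the remaining waves of the block along K; per the commit message, this interleaved layout is what enables the direct store. A hedged arithmetic sketch with hypothetical tile sizes, showing how the derived wave counts and the `IsWaveTileInterleavedFitting` condition from the gridwise base relate:

```cpp
// Hypothetical tile parameters, chosen only for illustration; real values come from the instance.
constexpr int BlockSize = 256, WaveSize = 32;
constexpr int NPerBlock = 128, NPerWmma = 16, NRepeat = 4;
constexpr int KPerBlock = 64,  KPack    = 16;

constexpr int WavesPerBlock = BlockSize / WaveSize;            // 8 waves in the block
constexpr int MNWaves_Gemm  = NPerBlock / NPerWmma / NRepeat;  // 2, as passed to ABTransferWaveTilesInterleave

// Grid-side split used by the interleaved transfer: keep the GEMM's MN waves,
// spread the rest of the block's waves along K.
constexpr int MNWaves_Grid  = MNWaves_Gemm;                           // 2
constexpr int KWaves_Grid   = WavesPerBlock / MNWaves_Gemm;           // 4
constexpr int KRepeat_Grid  = KPerBlock / (KWaves_Grid * KPack);      // 1
constexpr int MNRepeat_Grid = NPerBlock / (MNWaves_Grid * NPerWmma);  // 4

// Direct store is only enabled when every wave of the block gets at least one wave tile.
constexpr bool IsWaveTileInterleavedFitting =
    (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= WavesPerBlock;
static_assert(IsWaveTileInterleavedFitting, "all waves are covered");
static_assert(MNWaves_Grid * KWaves_Grid == WavesPerBlock, "waves exactly tile the MN x K wave split");

int main() { return 0; }
```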

View File

@@ -177,7 +177,8 @@ template <typename ALayout,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
bool ForceThreadTileTransfer = false>
bool ForceThreadTileTransfer = false,
bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
PermuteA,
PermuteB,
IsBPreShuffled,
ForceThreadTileTransfer>
ForceThreadTileTransfer,
IsFusedKernel>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
PermuteA,
PermuteB,
IsBPreShuffled,
ForceThreadTileTransfer>;
ForceThreadTileTransfer,
IsFusedKernel>;
using Base::I0;
using Base::I1;

View File

@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
@@ -24,6 +25,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
@@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);
@@ -167,7 +175,8 @@ template <typename ALayout,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
bool ForceThreadTileTransfer = false> // only needed for convolution (limitation)
bool ForceThreadTileTransfer = false, // only needed for convolution (limitation)
bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3_base
{
@@ -182,6 +191,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
using LDSTypeA =
typename std::conditional<(NumATensor > 1),
@@ -232,30 +242,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return 1;
}();
static constexpr index_t WaveSize =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default
// - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation)
// AK1Value == 8 is not really a limitation but a requirement for the method so
// it will stay
// - GemmSpecialization Default with transpose
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable =
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
static constexpr bool IsBWaveTransferApplicable =
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
static constexpr bool IsWaveTileInterleavedFitting =
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
// We need to investigate if it makes sense to remove cshuffle for smaller types.
// Currently we use direct store for NRepeat equal to 4 or 8. For 16-bit types this gives
// at least a 64-bit buffer store per thread, i.e. 128 bytes in total across 16 contiguous
// threads (a full cache line).
static constexpr bool UseDirectStore = is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 &&
NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) &&
!IsFusedKernel && IsWaveTileInterleavedFitting;
#else
static constexpr bool IsAWaveTransferApplicable = false;
static constexpr bool IsBWaveTransferApplicable = false;
static constexpr bool UseDirectStore = false;
#endif
static constexpr index_t WaveSize =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
static constexpr bool UseBlockPaddingA =
ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
using ATransfer = typename std::conditional<
@@ -293,7 +317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr bool UseBlockPaddingB =
BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
using BTransfer = typename std::conditional<
IsBPreShuffled,
ABTransferThreadTilesPreShuffle<BLayout,
@@ -309,16 +332,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BThreadTransferSrcResetCoordinateAfterRun>,
typename std::conditional<
IsBWaveTransferApplicable,
ABTransferWaveTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize>,
typename std::conditional<
UseDirectStore,
ABTransferWaveTilesInterleave<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize,
NPerBlock / NPerWmma / NRepeat>,
ABTransferWaveTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize>>::type,
ABTransferThreadTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
@@ -490,6 +526,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumATensor>{});
}
template <typename GridDescBase>
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc)
{
const auto M = base_desc.GetLength(I0);
const auto K = base_desc.GetLength(I1);
const auto AK0 = K / AK1Value;
constexpr bool padM = false;
constexpr bool padK = false;
return ATransfer::template MakeGridDescriptor<padM, padK>(base_desc, M, M, K, K, 0, AK0);
}
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
const index_t KPad,
@@ -516,6 +565,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumBTensor>{});
}
template <typename GridDescBase>
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc)
{
const auto N = base_desc.GetLength(I0);
const auto K = base_desc.GetLength(I1);
const auto BK0 = K / BK1Value;
constexpr bool padN = false;
constexpr bool padK = false;
return BTransfer::template MakeGridDescriptor<padN, padK>(base_desc, N, N, K, K, 0, BK0);
}
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
{
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
@@ -594,8 +656,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
@@ -679,6 +739,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
ThisThreadBlock,
BlockwiseGemmPipe>;
using EpilogueDirectStore = EpilogueDirectStore<DsDataType,
EDataType,
AccDataType,
MRepeat,
NRepeat,
CDEElementwiseOperation,
BlockwiseGemmPipe>;
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
DsDataType,
EDataType,
@@ -1000,18 +1068,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
max_lds_align)
: 0;
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
EpilogueType::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
if constexpr(EpilogueType::IsLDSNeeded())
{
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
EpilogueType::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto c_block_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
constexpr auto c_block_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
}
else
{
return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize;
}
}
template <index_t numElements, typename Type>
@@ -1148,7 +1224,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
// Epilogue:
// - CShuffle / direct store
// - Multiple Ds
// - Fused operations
epilogue_args.template Run<EGlobalMemoryDataOperation>(
c_thread_buf,
p_ds_grid,