Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)

- Add support for direct store in epilogue instead of cshuffle
- Add padding support for wave transfer without transpose
- Add wave transfer with interleaved layout to support direct store
- Enable new functionalities on GEMMs
- Add optional new functionality support for grouped convolution fwd
- Add some fast instances for grouped convolution fwd with new functionalities (proper tuning needed)
Author: Enrico Degregori
Date: 2026-01-14 11:02:19 +01:00 (committed by GitHub)
Parent: 51027474af
Commit: 693ff3bbb3
20 changed files with 948 additions and 155 deletions
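The headline change is that kernels now pick their epilogue at compile time instead of hard-coding the cshuffle path. Below is a minimal, self-contained sketch of that selection: the toy `ToyGridwiseGemm` and its byte counts are invented, and the assumption that the direct-store epilogue needs less LDS is an editorial guess; only the `std::conditional` pattern and the member names (`IsBWaveTransferApplicable`, `UseDirectStore`, `EpilogueDirectStore`, `EpilogueCShuffle`, `GetSharedMemoryNumberOfByte`) come from the diff.

```cpp
#include <cstdio>
#include <type_traits>

// Hypothetical epilogue policies; byte counts are made up for illustration.
struct EpilogueCShuffle    { static constexpr int lds_bytes = 16384; };
struct EpilogueDirectStore { static constexpr int lds_bytes = 8192; };

template <bool IsBWaveTransferApplicable_, bool UseDirectStore_>
struct ToyGridwiseGemm
{
    static constexpr bool IsBWaveTransferApplicable = IsBWaveTransferApplicable_;
    static constexpr bool UseDirectStore            = UseDirectStore_;

    // The LDS requirement now depends on which epilogue was selected.
    template <typename Epilogue>
    static constexpr int GetSharedMemoryNumberOfByte() { return Epilogue::lds_bytes; }
};

template <typename GridwiseGemm>
constexpr int lds_size()
{
    // Same selection pattern as the hunks below: direct store only when the
    // B wave transfer is applicable AND the instance asked for it.
    using EpilogueType =
        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
                                      GridwiseGemm::UseDirectStore,
                                  EpilogueDirectStore,
                                  EpilogueCShuffle>::type;
    return GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
}

int main()
{
    std::printf("%d\n", lds_size<ToyGridwiseGemm<true, true>>());  // direct store path
    std::printf("%d\n", lds_size<ToyGridwiseGemm<true, false>>()); // cshuffle path
}
```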

View File

@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
  const long_index_t c_batch_offset =
      amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
- constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-     typename GridwiseGemm::EpilogueCShuffle>();
+ using EpilogueType =
+     typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                   GridwiseGemm::UseDirectStore,
+                               typename GridwiseGemm::EpilogueDirectStore,
+                               typename GridwiseGemm::EpilogueCShuffle>::type;
+ constexpr index_t LDS_size =
+     GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
  __shared__ char p_shared[LDS_size];
  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
      splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
  });
- auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+ auto epilogue_args = EpilogueType{};
  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
      p_as_grid_shift,

View File

@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
      std::is_same_v<c_data_type, ck::bhalf_t>)))
  {
  #endif
- constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-     typename GridwiseGemm::EpilogueCShuffle>();
+ using EpilogueType =
+     typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                   GridwiseGemm::UseDirectStore,
+                               typename GridwiseGemm::EpilogueDirectStore,
+                               typename GridwiseGemm::EpilogueCShuffle>::type;
+ constexpr index_t LDS_size =
+     GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
  // The normal approach to batching would be to increase the grid size by just stretching out
  // the grid Z dimension (which is the outermost dimension), but this depends on lower level
  // functions not directly using the Z dimension for other calculations. As it turns out, k
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
      splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
  });
- auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+ auto epilogue_args = EpilogueType{};
  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
      p_as_grid_shift,

View File

@@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
  ComputeTypeA,
  ComputeTypeB,
  PermuteA,
- PermuteB>;
+ PermuteB,
+ false, // IsBPreShuffled
+ false, // ForceThreadTileTransfer
+ true>; // IsFusedKernel
  using ReduceTrait = ReduceTrait_<ReduceAccDataType,
                                   ReducePtrsGlobal,

View File

@@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
  ComputeTypeA,
  ComputeTypeB,
  PermuteA,
- PermuteB>;
+ PermuteB,
+ false,
+ false,
+ true>;
  // Welford 2nd part kernel
  template <typename DoPads, index_t MPerTile, index_t NPerTile>

View File

@@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
  ComputeTypeA,
  ComputeTypeB,
  PermuteA,
- PermuteB>;
+ PermuteB,
+ false,
+ false,
+ true>;
  using ReduceTrait = ReduceTrait_<ReduceAccDataType,
                                   ReducePtrsGlobal,
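The three device ops above (bias-add-reduce, multiple-D layernorm, reduce) all append the same three arguments to their `GridwiseGemm` instantiation. A plausible reading — sketched below with a toy template, not CK's actual declaration — is that the gridwise GEMM gained trailing boolean parameters with defaults (`IsBPreShuffled`, `ForceThreadTileTransfer`, `IsFusedKernel`, per the inline comments in the first hunk), so existing instantiations keep compiling while fused kernels opt in positionally.

```cpp
// Hypothetical toy mirroring the call sites above; the real GridwiseGemm
// signature in CK is much larger. Trailing defaulted bools keep old
// instantiations valid while new call sites opt in positionally.
template <typename DataType,
          bool PermuteA                = false,
          bool PermuteB                = false,
          bool IsBPreShuffled          = false,
          bool ForceThreadTileTransfer = false,
          bool IsFusedKernel           = false>
struct ToyGridwise
{
    static constexpr bool is_fused = IsFusedKernel;
};

// Matches the diff's argument order and per-argument comments:
using FusedInstance = ToyGridwise<float,
                                  false, // PermuteA
                                  false, // PermuteB
                                  false, // IsBPreShuffled
                                  false, // ForceThreadTileTransfer
                                  true>; // IsFusedKernel
static_assert(FusedInstance::is_fused, "fused path selected");
```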

View File

@@ -48,8 +48,8 @@ namespace {
   * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
   */
  template <typename GridwiseGemm,
-           typename AGridDesc_AK0_M_AK1,
-           typename BGridDesc_BK0_N_BK1,
+           typename AGridDesc_M_K,
+           typename BGridDesc_N_K,
            typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
            typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
            typename ComputePtrOffset, // For Batch (group) and N
@@ -63,8 +63,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
  #endif
  kernel_grouped_conv_fwd_wmma_cshuffle_v3(
      typename GridwiseGemm::Argument karg,
-     const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-     const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+     const AGridDesc_M_K a_grid_desc_m_k,
+     const BGridDesc_N_K b_grid_desc_n_k,
      const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
          ds_grid_desc_mblock_mperblock_nblock_nperblock,
      const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -82,13 +82,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
      std::is_same_v<e_data_type, ck::bhalf_t>)))
  {
  #endif
- __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
-     typename GridwiseGemm::EpilogueCShuffle>()];
- auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
- GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
-                            BGridDesc_BK0_N_BK1,
+ using EpilogueType =
+     typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                   GridwiseGemm::UseDirectStore,
+                               typename GridwiseGemm::EpilogueDirectStore,
+                               typename GridwiseGemm::EpilogueCShuffle>::type;
+ constexpr index_t LDS_size =
+     GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
+ __shared__ char p_shared[LDS_size];
+ auto epilogue_args = EpilogueType{};
+ const auto a_grid_desc_ak0_m_ak1 =
+     GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+ const auto b_grid_desc_bk0_n_bk1 =
+     GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+ GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
+                            decltype(b_grid_desc_bk0_n_bk1),
                             DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                             EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                             ComputePtrOffset,
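Note what this hunk moves across the kernel boundary: the host now passes plain (M, K) and (N, K) descriptors, and the kernel builds the K-split (AK0, M, AK1)/(BK0, N, BK1) views locally, forwarding their types to `Run` through `decltype` since only the kernel can name them. A host-side sketch of that pattern, with invented `DescMK`/`make_split_view`/`run` stand-ins:

```cpp
#include <cstdio>

// Invented stand-ins: a flat (M, K) "descriptor" and its K-split view.
struct DescMK { int M, K; };
template <int K1>
struct SplitView { int K0, M; }; // a (K0, M, K1) view of DescMK

template <int K1>
SplitView<K1> make_split_view(DescMK d) { return {d.K / K1, d.M}; }

// run() is templated on the view type, mirroring Run<decltype(a_view), ...>.
template <typename AView>
void run(const AView& a) { std::printf("K0=%d M=%d\n", a.K0, a.M); }

void kernel_entry(DescMK a_m_k)
{
    // The split happens inside the "kernel"; the host never names this type.
    const auto a_view = make_split_view<8>(a_m_k);
    run<decltype(a_view)>(a_view);
}

int main() { kernel_entry({256, 64}); }
```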
@@ -110,8 +123,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
  #endif
  #else
  ignore = karg;
- ignore = a_grid_desc_ak0_m_ak1;
- ignore = b_grid_desc_bk0_n_bk1;
+ ignore = a_grid_desc_m_k;
+ ignore = b_grid_desc_n_k;
  ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
  ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
  ignore = compute_ptr_offset_of_batch;
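The `#else` branch keeps the renamed parameters alive through `ignore =` assignments so builds for unsupported targets stay warning-clean. `ignore` here behaves like a `std::ignore`-style sink; a sketch of the idiom (not CK's exact definition):

```cpp
// Minimal sink object: assignment from anything compiles to a no-op,
// which is how `ignore = karg;` marks compiled-out parameters as "used".
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(const T&) const { return *this; }
};
inline constexpr ignore_t ignore{};

void kernel_stub(int karg, double a_grid_desc_m_k)
{
    ignore = karg; // silences unused-parameter warnings
    ignore = a_grid_desc_m_k;
}

int main() { kernel_stub(0, 0.0); }
```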
@@ -187,6 +200,7 @@ template <index_t NDimSpatial,
            index_t CDEBlockTransferScalarPerVector_NPerBlock,
            BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+           bool UseThreadTileTransfer = true,
            typename AComputeDataType =
                decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                        Number<0>,
@@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      NPerBlock / ClusterLengthNPerBlock>{};
  template <typename ALay>
- static auto
- MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+ static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
  {
      namespace ctc = tensor_layout::convolution;
      using Layout = std::conditional_t<
@@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      const auto in_gemmm_gemmk_desc =
          matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
-     const auto M = in_gemmm_gemmk_desc.GetLength(I0);
-     const auto K = in_gemmm_gemmk_desc.GetLength(I1);
-     const auto AK0 = K / AK1;
-     return transform_tensor_descriptor(in_gemmm_gemmk_desc,
-                                        make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                   make_pass_through_transform(M)),
-                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+     return in_gemmm_gemmk_desc;
  }
  template <typename BLay>
- static auto
- MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+ static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
  {
      namespace ctc = tensor_layout::convolution;
      using Layout = std::conditional_t<
@@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      const auto wei_gemmn_gemmk_desc =
          matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
-     const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
-     const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
-     const auto BK0 = K / BK1;
-     return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
-                                        make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                   make_pass_through_transform(N)),
-                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+     return wei_gemmn_gemmk_desc;
  }
  template <typename ELay>
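Both descriptor helpers now stop at the padded two-dimensional (M, K)/(N, K) descriptor; the deleted `transform_tensor_descriptor` calls, which unmerged K into (K0, K1), move into the gridwise layer (see the in-kernel `MakeAGridDescriptor_AK0_M_AK1` call earlier). The unmerge itself is plain mixed-radix index arithmetic, sketched below with an arbitrary `AK1`:

```cpp
#include <cassert>

// Unmerge K -> (K0, K1): a flat index k maps to k0 = k / K1, k1 = k % K1,
// and merges back as k = k0 * K1 + k1 (what make_unmerge_transform encodes).
constexpr int AK1 = 8; // arbitrary illustrative value

constexpr int k0_of(int k) { return k / AK1; }
constexpr int k1_of(int k) { return k % AK1; }
constexpr int merge(int k0, int k1) { return k0 * AK1 + k1; }

int main()
{
    for(int k = 0; k < 64; ++k)
        assert(merge(k0_of(k), k1_of(k)) == k); // round trip is lossless
}
```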
@@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      const auto out_gemmmraw_gemmnraw_desc =
          conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
+     // Force MN padding on the output tensor. This allows to use Gemm default or only K padding
+     // and remove some instructions in the hot loop (same approach used for gemm universal).
      if constexpr(CTranspose)
      {
-         constexpr auto matrix_padder_trans =
-             MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
-         return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+         constexpr auto matrix_padder_MN_padding_trans =
+             MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
+                 NPerBlock, MPerBlock, KPerBlock};
+         return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
      }
      else
      {
-         return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+         constexpr auto matrix_padder_MN_padding =
+             MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
+                 MPerBlock, NPerBlock, KPerBlock};
+         return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
      }
  }
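Forcing `GemmSpecialization::MNPadding` on E rounds its M and N up to whole block tiles regardless of `GemmSpec`, which is what lets the main loop drop boundary checks; note the transposed branch swaps the tile arguments (`NPerBlock` pads M). The rounding `MatrixPadder` applies is, in essence (a sketch, not CK's implementation):

```cpp
#include <cstdio>

// Round a raw dimension up to a multiple of the block tile, i.e.
// padded = ceil(raw / tile) * tile -- what MN padding does to M and N.
constexpr int pad_to_tile(int raw, int tile) { return ((raw + tile - 1) / tile) * tile; }

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 64; // illustrative tile sizes
    std::printf("M: %d -> %d\n", 1000, pad_to_tile(1000, MPerBlock)); // 1024
    std::printf("N: %d -> %d\n", 100, pad_to_tile(100, NPerBlock));   // 128
}
```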
@@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  BlkGemmPipelineVer,
  AComputeDataType,
  BComputeDataType,
- false, // PermuteA
- false, // PermuteB
- false, // IsBPreShuffled
- true>; // ForceThreadTileTransfer
+ false,                  // PermuteA
+ false,                  // PermuteB
+ false,                  // IsBPreShuffled
+ UseThreadTileTransfer>; // ForceThreadTileTransfer
  // TODO: Previously available template param DoElementwiseBeforeCShuffle!
@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  false, // PermuteB
  false, // PermuteA
  false, // IsBPreShuffled
- true>; // ForceThreadTileTransfer
+ true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer)
  using GridwiseGemmCTranspose =
      std::conditional_t<CTranspose, GridwiseGemmSwappedParams, GridwiseGemm>;
@@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      I1>;
  // desc for blockwise copy
- using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-     dummy_conv_to_gemm_transformer))>;
- using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
-     dummy_conv_to_gemm_transformer))>;
+ using AGridDesc_M_K =
+     remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
+ using BGridDesc_N_K =
+     remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
  // Argument
  struct Argument : public BaseArgument
@@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  ds_grid_desc_m_n_{},
  e_grid_desc_m_n_{
      DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
- a_grid_desc_ak0_m_ak1_{
-     MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
- b_grid_desc_bk0_n_bk1_{
-     MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
+ a_grid_desc_m_k_{MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
+ b_grid_desc_n_k_{MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
  compute_ptr_offset_of_groups_{},
  compute_ptr_offset_of_n_{},
  a_element_op_{a_element_op},
@@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  }
  {
-     const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1);
-     const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1);
+     const index_t GemmM = a_grid_desc_m_k_.GetLength(I0);
+     const index_t GemmN = b_grid_desc_n_k_.GetLength(I0);
      const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN)
                                     : GridwiseGemmCTranspose::CalculateMBlock(GemmM);
      const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM)
@@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
      is_same_v<ALayout, tensor_layout::convolution::NDHWGC>)
  {
      size_as_buffers[i] =
-         (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
+         (a_grid_desc_m_k_.GetElementSpaceSize() +
           (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) *
          sizeof(ADataType_single) / GridwiseGemm::APackedSize;
  }
@@ -891,13 +888,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  {
      if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1)
      {
-         size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
+         size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() +
                                (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) *
                               sizeof(ADataType_single) / GridwiseGemm::APackedSize;
      }
      else
      {
-         size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() *
+         size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() *
                               eff_num_group * sizeof(ADataType_single) /
                               GridwiseGemm::APackedSize;
      }
@@ -914,7 +911,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  static_for<0, NumBTensor, 1>{}([&](auto i) {
      using BDataType_single = remove_cvref_t<tuple_element_t<i.value, GemmBsDataType>>;
-     size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group *
+     size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group *
                           sizeof(BDataType_single) / GridwiseGemm::BPackedSize;
  });
@@ -961,8 +958,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  void Print() const
  {
-     std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
-     std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
+     std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl;
+     std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl;
      static_for<0, NumDTensor, 1>{}(
          [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
      std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -998,8 +995,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  DsGridDesc_M_N ds_grid_desc_m_n_;
  EGridDesc_M_N e_grid_desc_m_n_;
- AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
- BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+ AGridDesc_M_K a_grid_desc_m_k_;
+ BGridDesc_N_K b_grid_desc_n_k_;
  DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
      ds_grid_desc_mblock_mperblock_nblock_nperblock_;
  EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  constexpr index_t minimum_occupancy =
      BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
- const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
- const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
- const index_t GemmK =
-     arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+ const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
+ const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
+ const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
  const index_t num_workgroups_per_Conv_N =
      arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
@@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  dim3(BlockSize),
  0,
  gemm_arg_,
- arg.b_grid_desc_bk0_n_bk1_,
- arg.a_grid_desc_ak0_m_ak1_,
+ arg.b_grid_desc_n_k_,
+ arg.a_grid_desc_m_k_,
  arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.compute_ptr_offset_of_groups_,
@@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  dim3(BlockSize),
  0,
  gemm_arg,
- arg.b_grid_desc_bk0_n_bk1_,
- arg.a_grid_desc_ak0_m_ak1_,
+ arg.b_grid_desc_n_k_,
+ arg.a_grid_desc_m_k_,
  arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.compute_ptr_offset_of_groups_,
@@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  dim3(BlockSize),
  0,
  gemm_arg_,
- arg.a_grid_desc_ak0_m_ak1_,
- arg.b_grid_desc_bk0_n_bk1_,
+ arg.a_grid_desc_m_k_,
+ arg.b_grid_desc_n_k_,
  arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.compute_ptr_offset_of_groups_,
@@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  dim3(BlockSize),
  0,
  gemm_arg,
- arg.a_grid_desc_ak0_m_ak1_,
- arg.b_grid_desc_bk0_n_bk1_,
+ arg.a_grid_desc_m_k_,
+ arg.b_grid_desc_n_k_,
  arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
  arg.compute_ptr_offset_of_groups_,
@@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  {
      const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
          GridwiseGemmCTranspose,
-         DeviceOp::BGridDesc_BK0_N_BK1,
-         DeviceOp::AGridDesc_AK0_M_AK1,
+         DeviceOp::BGridDesc_N_K,
+         DeviceOp::AGridDesc_M_K,
          DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          ComputePtrOffset,
@@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  {
      const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
          GridwiseGemm,
-         DeviceOp::AGridDesc_AK0_M_AK1,
-         DeviceOp::BGridDesc_BK0_N_BK1,
+         DeviceOp::AGridDesc_M_K,
+         DeviceOp::BGridDesc_N_K,
          DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          ComputePtrOffset,
@@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
  }
  // check Gridwise GEMM
- const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
- const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
- const index_t GemmK =
-     arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+ const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
+ const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
+ const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
  if constexpr(CTranspose)
  {

View File

@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
  const CDEElementwiseOperation cde_element_op)
  {
  #if(defined(__gfx11__) || defined(__gfx12__))
- constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-     typename GridwiseGemm::EpilogueCShuffle>();
+ using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                                    GridwiseGemm::UseDirectStore,
+                                                typename GridwiseGemm::EpilogueDirectStore,
+                                                typename GridwiseGemm::EpilogueCShuffle>::type;
+ constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
  __shared__ uint8_t p_shared[LDS_size];
  const auto gemm_desc_ptr =
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
  gemm_desc_ptr[group_id].StrideE,
  1);
- auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+ auto epilogue_args = EpilogueType{};
  constexpr TailNumber TailNum = TailNumber::Full;
  if(has_main_k_block_loop)

View File

@@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
  const index_t group_count)
  {
  #if(defined(__gfx11__) || defined(__gfx12__))
- constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-     typename GridwiseGemm::EpilogueCShuffle>();
+ using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                                    GridwiseGemm::UseDirectStore,
+                                                typename GridwiseGemm::EpilogueDirectStore,
+                                                typename GridwiseGemm::EpilogueCShuffle>::type;
+ constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
  __shared__ char p_shared[LDS_size];
  const index_t block_id = get_block_1d_id();
@@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
  auto splitk_batch_offset =
      typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
- auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+ auto epilogue_args = EpilogueType{};
  GridwiseGemm::template Run<HasMainKBlockLoop,
                             CGlobalMemoryDataOperation,
                             TailNum,
                             Block2CTileMap,
-                            typename GridwiseGemm::EpilogueCShuffle,
+                            EpilogueType,
                             1, // Block2CTileMap MBlock index
                             2  // Block2CTileMap NBlock index
                             >(static_cast<void*>(p_shared),
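This final hunk shows the other half of the dispatch: the selected `EpilogueType` is threaded into `Run` as an explicit template argument, so the epilogue is resolved at compile time with no runtime branch in the kernel. A host-side sketch of that policy-parameter pattern (all names invented):

```cpp
#include <cstdio>

struct CShuffleEpilogue    { static void store() { std::puts("LDS shuffle, then vector store"); } };
struct DirectStoreEpilogue { static void store() { std::puts("store accumulators directly"); } };

// The GEMM body is written once; the epilogue policy is a template parameter,
// mirroring Run<..., EpilogueType, ...>(...) in the diff above.
template <typename Epilogue>
void run_gemm()
{
    // ... main K loop would go here ...
    Epilogue::store();
}

int main()
{
    run_gemm<CShuffleEpilogue>();
    run_gemm<DirectStoreEpilogue>();
}
```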