mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)
- Add support for direct store in epilogue instead of cshuffle
- Add padding support for wave transfer without transpose
- Add wave transfer with interleaved layout to support direct store
- Enable new functionalities on GEMMs
- Add optional new functionality support for grouped convolution fwd
- Add some fast instances for grouped convolution fwd with new functionalities (proper tuning needed)
This commit is contained in:
@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
|
||||
@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
|
||||
@@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
PermuteB,
|
||||
false, // IsBPreShuffled
|
||||
false, // ForceThreadTileTransfer
|
||||
true>; // IsFusedKernel
|
||||
|
||||
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
|
||||
ReducePtrsGlobal,
|
||||
|
||||
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
PermuteB,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
// Welford 2nd part kernel
|
||||
template <typename DoPads, index_t MPerTile, index_t NPerTile>
|
||||
|
||||
@@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
PermuteB,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
|
||||
ReducePtrsGlobal,
|
||||
|
||||
@@ -48,8 +48,8 @@ namespace {
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffset, // For Batch (group) and N
|
||||
@@ -63,8 +63,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const AGridDesc_M_K a_grid_desc_m_k,
|
||||
const BGridDesc_N_K b_grid_desc_n_k,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -82,13 +82,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>()];
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
|
||||
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffset,
|
||||
@@ -110,8 +123,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = a_grid_desc_m_k;
|
||||
ignore = b_grid_desc_n_k;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
@@ -187,6 +200,7 @@ template <index_t NDimSpatial,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool UseThreadTileTransfer = true,
|
||||
typename AComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
@@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
@@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
|
||||
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(in_gemmm_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
@@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
|
||||
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
|
||||
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
return wei_gemmn_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
@@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
// Force MN padding on the output tensor. This allows using Gemm default or only K padding
|
||||
// and remove some instructions in the hot loop (same approach used for gemm universal).
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
constexpr auto matrix_padder_trans =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
|
||||
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
constexpr auto matrix_padder_MN_padding_trans =
|
||||
MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
|
||||
NPerBlock, MPerBlock, KPerBlock};
|
||||
return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
constexpr auto matrix_padder_MN_padding =
|
||||
MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
BlkGemmPipelineVer,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
false, // PermuteA
|
||||
false, // PermuteB
|
||||
false, // IsBPreShuffled
|
||||
true>; // ForceThreadTileTransfer
|
||||
false, // PermuteA
|
||||
false, // PermuteB
|
||||
false, // IsBPreShuffled
|
||||
UseThreadTileTransfer>; // ForceThreadTileTransfer
|
||||
|
||||
// TODO: Previously available template param DoElementwiseBeforeCShuffle!
|
||||
|
||||
@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
false, // PermuteB
|
||||
false, // PermuteA
|
||||
false, // IsBPreShuffled
|
||||
true>; // ForceThreadTileTransfer
|
||||
true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer)
|
||||
|
||||
using GridwiseGemmCTranspose =
|
||||
std::conditional_t<CTranspose, GridwiseGemmSwappedParams, GridwiseGemm>;
|
||||
@@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
I1>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_conv_to_gemm_transformer))>;
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
|
||||
using BGridDesc_N_K =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
|
||||
a_grid_desc_m_k_{MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
|
||||
b_grid_desc_n_k_{MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
}
|
||||
|
||||
{
|
||||
const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
const index_t GemmM = a_grid_desc_m_k_.GetLength(I0);
|
||||
const index_t GemmN = b_grid_desc_n_k_.GetLength(I0);
|
||||
const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN)
|
||||
: GridwiseGemmCTranspose::CalculateMBlock(GemmM);
|
||||
const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM)
|
||||
@@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
size_as_buffers[i] =
|
||||
(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
|
||||
(a_grid_desc_m_k_.GetElementSpaceSize() +
|
||||
(num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) *
|
||||
sizeof(ADataType_single) / GridwiseGemm::APackedSize;
|
||||
}
|
||||
@@ -891,13 +888,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
{
|
||||
if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1)
|
||||
{
|
||||
size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
|
||||
size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() +
|
||||
(eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) *
|
||||
sizeof(ADataType_single) / GridwiseGemm::APackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() *
|
||||
size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() *
|
||||
eff_num_group * sizeof(ADataType_single) /
|
||||
GridwiseGemm::APackedSize;
|
||||
}
|
||||
@@ -914,7 +911,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_single = remove_cvref_t<tuple_element_t<i.value, GemmBsDataType>>;
|
||||
size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group *
|
||||
size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group *
|
||||
sizeof(BDataType_single) / GridwiseGemm::BPackedSize;
|
||||
});
|
||||
|
||||
@@ -961,8 +958,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
@@ -998,8 +995,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
@@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
|
||||
const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
|
||||
const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
|
||||
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
|
||||
@@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
@@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
@@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
@@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
@@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
|
||||
GridwiseGemmCTranspose,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_N_K,
|
||||
DeviceOp::AGridDesc_M_K,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffset,
|
||||
@@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc_M_K,
|
||||
DeviceOp::BGridDesc_N_K,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffset,
|
||||
@@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
|
||||
const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
|
||||
const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);
|
||||
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
|
||||
@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
__shared__ uint8_t p_shared[LDS_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
1);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = EpilogueType{};
|
||||
constexpr TailNumber TailNum = TailNumber::Full;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
|
||||
@@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const index_t group_count)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
auto splitk_batch_offset =
|
||||
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
typename GridwiseGemm::EpilogueCShuffle,
|
||||
EpilogueType,
|
||||
1, // Block2CTileMap MBlock index
|
||||
2 // Block2CTileMap NBlock index
|
||||
>(static_cast<void*>(p_shared),
|
||||
|
||||
Reference in New Issue
Block a user