Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)
- Add support for direct store in epilogue instead of cshuffle
- Add padding support for wave transfer without transpose
- Add wave transfer with interleaved layout to support direct store
- Enable new functionalities on GEMMs
- Add optional new functionality support for grouped convolution fwd
- Add some fast instances for grouped convolution fwd with new functionalities (proper tuning needed)
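The same epilogue-selection pattern recurs in every kernel touched below; condensed here for reference (all names are taken from the diff itself, the surrounding kernel code is elided):

    // The direct-store epilogue is picked only when the B wave-tile transfer
    // path applies and direct store is enabled; otherwise CShuffle is kept.
    using EpilogueType =
        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
                                      GridwiseGemm::UseDirectStore,
                                  typename GridwiseGemm::EpilogueDirectStore,
                                  typename GridwiseGemm::EpilogueCShuffle>::type;

    // LDS is sized for the selected epilogue; the direct-store path reports
    // IsLDSNeeded() == false, so no C-shuffle scratch space is reserved.
    constexpr index_t LDS_size =
        GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    __shared__ char p_shared[LDS_size];

    auto epilogue_args = EpilogueType{};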
@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
    const long_index_t c_batch_offset =
        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

-    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>();
+    using EpilogueType =
+        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                      GridwiseGemm::UseDirectStore,
+                                  typename GridwiseGemm::EpilogueDirectStore,
+                                  typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    __shared__ char p_shared[LDS_size];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
            splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
    });

-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    auto epilogue_args = EpilogueType{};

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        p_as_grid_shift,
@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                           std::is_same_v<c_data_type, ck::bhalf_t>)))
{
#endif
-    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>();
+    using EpilogueType =
+        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                      GridwiseGemm::UseDirectStore,
+                                  typename GridwiseGemm::EpilogueDirectStore,
+                                  typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    // The normal approach to batching would be to increase the grid size by just stretching out
    // the grid Z dimension (which is the outermost dimension), but this depends on lower level
    // functions not directly using the Z dimension for other calculations. As it turns out, k
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
            splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
    });

-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    auto epilogue_args = EpilogueType{};

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        p_as_grid_shift,
@@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
-        PermuteB>;
+        PermuteB,
+        false, // IsBPreShuffled
+        false, // ForceThreadTileTransfer
+        true>; // IsFusedKernel

    using ReduceTrait = ReduceTrait_<ReduceAccDataType,
                                     ReducePtrsGlobal,
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
-        PermuteB>;
+        PermuteB,
+        false,
+        false,
+        true>;

    // Welford 2nd part kernel
    template <typename DoPads, index_t MPerTile, index_t NPerTile>
@@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
-        PermuteB>;
+        PermuteB,
+        false,
+        false,
+        true>;

    using ReduceTrait = ReduceTrait_<ReduceAccDataType,
                                     ReducePtrsGlobal,
@@ -48,8 +48,8 @@ namespace {
 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
 */
template <typename GridwiseGemm,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
+          typename AGridDesc_M_K,
+          typename BGridDesc_N_K,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename ComputePtrOffset, // For Batch (group) and N
@@ -63,8 +63,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_grouped_conv_fwd_wmma_cshuffle_v3(
        typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+        const AGridDesc_M_K a_grid_desc_m_k,
+        const BGridDesc_N_K b_grid_desc_n_k,
        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock,
        const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -82,13 +82,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                           std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
-    __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>()];
-
-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
-
-    GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
-                               BGridDesc_BK0_N_BK1,
+    using EpilogueType =
+        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                      GridwiseGemm::UseDirectStore,
+                                  typename GridwiseGemm::EpilogueDirectStore,
+                                  typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
+    __shared__ char p_shared[LDS_size];
+
+    auto epilogue_args = EpilogueType{};
+
+    const auto a_grid_desc_ak0_m_ak1 =
+        GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+
+    const auto b_grid_desc_bk0_n_bk1 =
+        GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
+    GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
+                               decltype(b_grid_desc_bk0_n_bk1),
                               DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                               EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                               ComputePtrOffset,
@@ -110,8 +123,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
#else
    ignore = karg;
-    ignore = a_grid_desc_ak0_m_ak1;
-    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = a_grid_desc_m_k;
+    ignore = b_grid_desc_n_k;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = compute_ptr_offset_of_batch;
@@ -187,6 +200,7 @@ template <index_t NDimSpatial,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+          bool UseThreadTileTransfer = true,
          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
@@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                                 NPerBlock / ClusterLengthNPerBlock>{};

    template <typename ALay>
-    static auto
-    MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+    static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
    {
        namespace ctc = tensor_layout::convolution;
        using Layout = std::conditional_t<
@@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

-        const auto M = in_gemmm_gemmk_desc.GetLength(I0);
-        const auto K = in_gemmm_gemmk_desc.GetLength(I1);
-
-        const auto AK0 = K / AK1;
-
-        return transform_tensor_descriptor(in_gemmm_gemmk_desc,
-                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                      make_pass_through_transform(M)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return in_gemmm_gemmk_desc;
    }

    template <typename BLay>
-    static auto
-    MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
+    static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
    {
        namespace ctc = tensor_layout::convolution;
        using Layout = std::conditional_t<
@@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        const auto wei_gemmn_gemmk_desc =
            matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);

-        const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
-        const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
-
-        const auto BK0 = K / BK1;
-
-        return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
-                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                      make_pass_through_transform(N)),
-                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return wei_gemmn_gemmk_desc;
    }

    template <typename ELay>
@@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        const auto out_gemmmraw_gemmnraw_desc =
            conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();

+        // Force MN padding on the output tensor. This allows to use Gemm default or only K padding
+        // and remove some instructions in the hot loop (same approach used for gemm universal).
        if constexpr(CTranspose)
        {
-            constexpr auto matrix_padder_trans =
-                MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
-            return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+            constexpr auto matrix_padder_MN_padding_trans =
+                MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
+                    NPerBlock, MPerBlock, KPerBlock};
+            return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
        }
        else
        {
-            return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+            constexpr auto matrix_padder_MN_padding =
+                MatrixPadder<GemmSpecialization::MNPadding, index_t, index_t, index_t>{
+                    MPerBlock, NPerBlock, KPerBlock};
+            return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
        }
    }
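A quick illustration of what forcing MNPadding buys here (hypothetical sizes, not from this diff): the output M and N are rounded up to block-tile multiples, so the epilogue needs no per-element bounds checks in the hot loop.

    // Hypothetical sizes: MPerBlock = 128, raw output M = 1000.
    constexpr int MPerBlock = 128;
    int M_raw    = 1000;
    int M_padded = ((M_raw + MPerBlock - 1) / MPerBlock) * MPerBlock; // 1024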
@@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        BlkGemmPipelineVer,
        AComputeDataType,
        BComputeDataType,
-        false, // PermuteA
-        false, // PermuteB
-        false, // IsBPreShuffled
-        true>; // ForceThreadTileTransfer
+        false, // PermuteA
+        false, // PermuteB
+        false, // IsBPreShuffled
+        UseThreadTileTransfer>; // ForceThreadTileTransfer

    // TODO: Previously available template param DoElementwiseBeforeCShuffle!

@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        false, // PermuteB
        false, // PermuteA
        false, // IsBPreShuffled
-        true>; // ForceThreadTileTransfer
+        true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer)

    using GridwiseGemmCTranspose =
        std::conditional_t<CTranspose, GridwiseGemmSwappedParams, GridwiseGemm>;
@@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                                     I1>;

    // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-        dummy_conv_to_gemm_transformer))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
-        dummy_conv_to_gemm_transformer))>;
+    using AGridDesc_M_K =
+        remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
+    using BGridDesc_N_K =
+        remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;

    // Argument
    struct Argument : public BaseArgument
@@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{
                  DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
-              a_grid_desc_ak0_m_ak1_{
-                  MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
-              b_grid_desc_bk0_n_bk1_{
-                  MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
+              a_grid_desc_m_k_{MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
+              b_grid_desc_n_k_{MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
              compute_ptr_offset_of_groups_{},
              compute_ptr_offset_of_n_{},
              a_element_op_{a_element_op},
@@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        }

        {
-            const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1);
-            const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1);
+            const index_t GemmM = a_grid_desc_m_k_.GetLength(I0);
+            const index_t GemmN = b_grid_desc_n_k_.GetLength(I0);
            const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN)
                                           : GridwiseGemmCTranspose::CalculateMBlock(GemmM);
            const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM)
@@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                   is_same_v<ALayout, tensor_layout::convolution::NDHWGC>)
                {
                    size_as_buffers[i] =
-                        (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
+                        (a_grid_desc_m_k_.GetElementSpaceSize() +
                         (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) *
                        sizeof(ADataType_single) / GridwiseGemm::APackedSize;
                }
@@ -891,13 +888,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                {
                    if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1)
                    {
-                        size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() +
+                        size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() +
                                              (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) *
                                             sizeof(ADataType_single) / GridwiseGemm::APackedSize;
                    }
                    else
                    {
-                        size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() *
+                        size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() *
                                             eff_num_group * sizeof(ADataType_single) /
                                             GridwiseGemm::APackedSize;
                    }
@@ -914,7 +911,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3

            static_for<0, NumBTensor, 1>{}([&](auto i) {
                using BDataType_single = remove_cvref_t<tuple_element_t<i.value, GemmBsDataType>>;
-                size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group *
+                size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group *
                                     sizeof(BDataType_single) / GridwiseGemm::BPackedSize;
            });

@@ -961,8 +958,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3

        void Print() const
        {
-            std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
-            std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
+            std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl;
+            std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl;
            static_for<0, NumDTensor, 1>{}(
                [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
            std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -998,8 +995,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

-        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
-        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        AGridDesc_M_K a_grid_desc_m_k_;
+        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

-        const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-        const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
-        const index_t GemmK =
-            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+        const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
+        const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
+        const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);

        const index_t num_workgroups_per_Conv_N =
            arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
@@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                dim3(BlockSize),
                0,
                gemm_arg_,
-                arg.b_grid_desc_bk0_n_bk1_,
-                arg.a_grid_desc_ak0_m_ak1_,
+                arg.b_grid_desc_n_k_,
+                arg.a_grid_desc_m_k_,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.compute_ptr_offset_of_groups_,
@@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                dim3(BlockSize),
                0,
                gemm_arg,
-                arg.b_grid_desc_bk0_n_bk1_,
-                arg.a_grid_desc_ak0_m_ak1_,
+                arg.b_grid_desc_n_k_,
+                arg.a_grid_desc_m_k_,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.compute_ptr_offset_of_groups_,
@@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                dim3(BlockSize),
                0,
                gemm_arg_,
-                arg.a_grid_desc_ak0_m_ak1_,
-                arg.b_grid_desc_bk0_n_bk1_,
+                arg.a_grid_desc_m_k_,
+                arg.b_grid_desc_n_k_,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.compute_ptr_offset_of_groups_,
@@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
                dim3(BlockSize),
                0,
                gemm_arg,
-                arg.a_grid_desc_ak0_m_ak1_,
-                arg.b_grid_desc_bk0_n_bk1_,
+                arg.a_grid_desc_m_k_,
+                arg.b_grid_desc_n_k_,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.compute_ptr_offset_of_groups_,
@@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        {
            const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
                GridwiseGemmCTranspose,
-                DeviceOp::BGridDesc_BK0_N_BK1,
-                DeviceOp::AGridDesc_AK0_M_AK1,
+                DeviceOp::BGridDesc_N_K,
+                DeviceOp::AGridDesc_M_K,
                DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                ComputePtrOffset,
@@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        {
            const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3<
                GridwiseGemm,
-                DeviceOp::AGridDesc_AK0_M_AK1,
-                DeviceOp::BGridDesc_BK0_N_BK1,
+                DeviceOp::AGridDesc_M_K,
+                DeviceOp::BGridDesc_N_K,
                DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                ComputePtrOffset,
@@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
        }

        // check Gridwise GEMM
-        const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-        const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
-        const index_t GemmK =
-            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+        const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0);
+        const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0);
+        const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1);

        if constexpr(CTranspose)
        {
@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    const CDEElementwiseOperation cde_element_op)
{
#if(defined(__gfx11__) || defined(__gfx12__))
-    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>();
+    using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                                       GridwiseGemm::UseDirectStore,
+                                                   typename GridwiseGemm::EpilogueDirectStore,
+                                                   typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    __shared__ uint8_t p_shared[LDS_size];

    const auto gemm_desc_ptr =
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
                gemm_desc_ptr[group_id].StrideE,
                1);

-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    auto epilogue_args = EpilogueType{};
    constexpr TailNumber TailNum = TailNumber::Full;

    if(has_main_k_block_loop)
@@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
    const index_t group_count)
{
#if(defined(__gfx11__) || defined(__gfx12__))
-    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>();
+    using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                                       GridwiseGemm::UseDirectStore,
+                                                   typename GridwiseGemm::EpilogueDirectStore,
+                                                   typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    __shared__ char p_shared[LDS_size];

    const index_t block_id = get_block_1d_id();
@@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)

        auto splitk_batch_offset =
            typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
-        auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+        auto epilogue_args = EpilogueType{};

        GridwiseGemm::template Run<HasMainKBlockLoop,
                                   CGlobalMemoryDataOperation,
                                   TailNum,
                                   Block2CTileMap,
-                                   typename GridwiseGemm::EpilogueCShuffle,
+                                   EpilogueType,
                                   1, // Block2CTileMap MBlock index
                                   2 // Block2CTileMap NBlock index
                                   >(static_cast<void*>(p_shared),
@@ -59,6 +59,8 @@ struct EpilogueCShuffleBase
        1,
        CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;

+    __device__ static constexpr bool IsLDSNeeded() { return true; }
+
    // *Caution Here repeat is shuffle repeat
    __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp (new file, 145 lines)
@@ -0,0 +1,145 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          index_t MRepeat,
          index_t NRepeat,
          typename CDEElementwiseOperation,
          typename BlockwiseGemmPipe>
struct EpilogueDirectStore
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    __device__ static constexpr bool IsLDSNeeded() { return false; }

    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ static void Run(CThreadBuf& c_thread_buf,
                               DsGridPointer,
                               EDataType* p_e_grid,
                               void*,
                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
    {
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // C mapping in single thread.
        constexpr auto
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                BlockwiseGemmPipe::
                    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // C mapping in single block
        constexpr auto
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                BlockwiseGemmPipe::
                    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        constexpr auto MWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I1);
        constexpr auto MSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I2);
        constexpr auto NWave =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I4);
        constexpr auto NThreadPerSubGroup =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I5);
        constexpr auto MAccVgprs =
            c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                .GetLength(I6);

        // origin
        const auto c_thread_mtx_on_block =
            BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);

        const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
                make_multi_index(c_thread_mtx_on_block[I0]));

        const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                make_multi_index(c_thread_mtx_on_block[I1]));

        // E grid descriptor
        const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
            transform_tensor_descriptor(
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(make_freeze_transform(block_m_id),
                           make_unmerge_transform(make_tuple(Number<MRepeat>{},
                                                             Number<MWave>{},
                                                             Number<MSubGroup>{},
                                                             Number<MAccVgprs>{})),
                           make_freeze_transform(block_n_id),
                           make_unmerge_transform(make_tuple(
                               Number<NWave>{}, Number<NThreadPerSubGroup>{}, Number<NRepeat>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{}));

        auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType,
            EDataType,
            decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
            decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
            CDEElementwiseOperation,
            Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
            Sequence<0, 1, 2, 3, 4, 5, 6>,
            3,
            NRepeat, // VectorSize
            EGlobalMemoryDataOperation,
            1,
            false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                   make_multi_index(m_thread_data_on_grid_idx[I0],
                                    m_thread_data_on_grid_idx[I1],
                                    m_thread_data_on_grid_idx[I2],
                                    n_thread_data_on_grid_idx[I0],
                                    n_thread_data_on_grid_idx[I1],
                                    n_thread_data_on_grid_idx[I2],
                                    m_thread_data_on_grid_idx[I3]),
                   cde_element_op};

        c_thread_copy.Run(
            c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
            make_tuple(I0, I0, I0, I0, I0, I0, I0),
            c_thread_buf,
            c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
            e_grid_buf);
    }
};

} // namespace ck
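A rough check of the store width this epilogue issues, under the constraints UseDirectStore imposes later in the gridwise base (16-bit EDataType, NRepeat of 4 or 8); the 128-byte figure echoes the full-cache-line comment there (illustrative arithmetic, not from the diff):

    // VectorSize == NRepeat above, so each thread stores NRepeat contiguous
    // 2-byte elements per c_thread_copy vector.
    constexpr int NRepeat          = 4;                     // supported: 4 or 8
    constexpr int bytes_per_thread = NRepeat * 2;           // 8 B -> one 64-bit store
    constexpr int bytes_16_lanes   = 16 * bytes_per_thread; // 128 B, a full cache line
    static_assert(bytes_16_lanes == 128, "matches the full-cache-line rationale");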
@@ -77,26 +77,79 @@ struct ABTransferWaveTiles
    static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
    static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);

+    template <bool PadMN, bool PadK, typename GridDescriptorBase>
+    __host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc,
+                                                      index_t sizeMN,
+                                                      index_t MNPad,
+                                                      index_t sizeK,
+                                                      index_t KPad,
+                                                      index_t,
+                                                      index_t)
+    {
+        if constexpr(PadMN && PadK)
+        {
+            // pad both MN and K
+            return transform_tensor_descriptor(
+                base_desc,
+                make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
+                           make_right_pad_transform(sizeK, KPad - sizeK)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(PadMN && !PadK)
+        {
+            // pad MN, but not K
+            return transform_tensor_descriptor(
+                base_desc,
+                make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
+                           make_pass_through_transform(sizeK)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(!PadMN && PadK)
+        {
+            // pad K, but not MN
+            return transform_tensor_descriptor(
+                base_desc,
+                make_tuple(make_pass_through_transform(sizeMN),
+                           make_right_pad_transform(sizeK, KPad - sizeK)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            // not pad MN or K
+            return base_desc;
+        }
+    }
+
    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
                                                       index_t sizeMN,
-                                                      index_t,
+                                                      index_t MNPad,
                                                       index_t sizeK,
-                                                      index_t,
+                                                      index_t KPad,
                                                       index_t,
                                                       index_t)
    {
-        // Notes: padding is currently not supported
-        static_assert(!PadMN && !PadK, "padding is currently not supported");
+        // Notes: padding is currently not supported with transpose
+        static_assert(!((PadMN || PadK) && ABDoTranspose),
+                      "padding is currently not supported with transpose");
+
+        const index_t MN_grid = !PadMN ? sizeMN : MNPad;
+        const index_t K_grid = !PadK ? sizeK : KPad;
+
+        const auto base_desc_padded =
+            PadGridDescriptor<PadMN, PadK>(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);

        // Divide the base descriptor MN_K into tiles
        const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
-            base_desc,
+            base_desc_padded,
            make_tuple(
                make_unmerge_transform(make_tuple(
-                    math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
-                make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number<KPack>{}),
-                                                  Number<KPack>{}))),
+                    math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
+                make_unmerge_transform(make_tuple(
+                    math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
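A hedged usage sketch of the new PadGridDescriptor (hypothetical caller; base_desc, sizeMN, MNPad, sizeK and KPad are assumed to come from the surrounding transfer setup, and the two trailing arguments are unused placeholders, as in the diff):

    // Pad only the K axis of an MN x K base descriptor; MN passes through.
    constexpr bool PadMN = false;
    constexpr bool PadK  = true;
    const auto padded_desc =
        PadGridDescriptor<PadMN, PadK>(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);
    // padded_desc now has lengths (sizeMN, KPad); the KPad - sizeK trailing
    // elements exist only in the coordinate space added by the right-pad transform.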
@@ -112,9 +165,9 @@ struct ABTransferWaveTiles
            transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles,
                make_tuple(make_pass_through_transform(
-                               math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
+                               math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
                           make_pass_through_transform(
-                               math::integer_divide_ceil(sizeK, Number<KPack>{})),
+                               math::integer_divide_ceil(K_grid, Number<KPack>{})),
                           make_pass_through_transform(Number<MNPerWmma>{}),
                           make_unmerge_transform(
                               make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
@@ -127,8 +180,8 @@ struct ABTransferWaveTiles
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
-                        math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
-                    make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
+                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
+                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_pass_through_transform(Number<MNPerWmma>{}),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0)),
@@ -143,9 +196,9 @@ struct ABTransferWaveTiles
            transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles,
                make_tuple(make_pass_through_transform(
-                               math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
+                               math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
                           make_pass_through_transform(
-                               math::integer_divide_ceil(sizeK, Number<KPack>{})),
+                               math::integer_divide_ceil(K_grid, Number<KPack>{})),
                           make_unmerge_transform(
                               make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
                           make_pass_through_transform(Number<KPack>{})),
@@ -157,8 +210,8 @@ struct ABTransferWaveTiles
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
-                        math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
-                    make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
+                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma>{})),
+                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0),
                    make_pass_through_transform(Number<KPack>{})),
include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp (new file, 275 lines)
@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/utility/amd_address_space.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/utility/math.hpp"

namespace ck {

template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t KPack,
          index_t ABK1Value,
          index_t WaveSize,
          index_t MNWaves_Gemm>
struct ABTransferWaveTilesInterleave : ABTransferWaveTiles<ABLayout,
                                                           ABMajorLayout,
                                                           LDSTypeAB,
                                                           BlockSize,
                                                           MNPerBlock,
                                                           KPerBlock,
                                                           MNPerWmma,
                                                           KPack,
                                                           ABK1Value,
                                                           WaveSize>
{
    using Base = ABTransferWaveTiles<ABLayout,
                                     ABMajorLayout,
                                     LDSTypeAB,
                                     BlockSize,
                                     MNPerBlock,
                                     KPerBlock,
                                     MNPerWmma,
                                     KPack,
                                     ABK1Value,
                                     WaveSize>;

    using Base::ABDoTranspose;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::MNKRow;

    using Base::GetBlockLaneIdx;
    using Base::GetBlockStep;
    using Base::GetGridLaneIdx;
    using Base::GetWaveIdx;
    using Base::PadGridDescriptor;
    using typename Base::ThisThreadBlock;

    static constexpr auto I4 = Number<4>{};

    static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet");

    using Base::KRepeat_;
    using Base::KWaves_;
    using Base::MNRepeat_;

    static constexpr index_t MNWaves_Grid = MNWaves_Gemm;
    static constexpr index_t KWaves_Grid = (BlockSize / WaveSize) / MNWaves_Gemm;
    static constexpr index_t KRepeat_Grid = KPerBlock / (KWaves_Grid * KPack);
    static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma);

    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
                                                       index_t sizeMN,
                                                       index_t MNPad,
                                                       index_t sizeK,
                                                       index_t KPad,
                                                       index_t,
                                                       index_t)
    {
        const auto base_desc_padded = Base::template PadGridDescriptor<PadMN, PadK>(
            base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);

        const index_t MN_grid = !PadMN ? sizeMN : MNPad;
        const index_t K_grid = !PadK ? sizeK : KPad;

        // Divide the base descriptor MN_K into tiles
        const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
            base_desc_padded,
            make_tuple(make_unmerge_transform(make_tuple(
                           math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{}),
                           Number<MNPerWmma * MNRepeat_Grid>{})),
                       make_unmerge_transform(make_tuple(
                           math::integer_divide_ceil(K_grid, Number<KPack>{}), Number<KPack>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        // The distinction is needed to get the same global indices for both layouts
        // Divide each tile in 2 16x8 subtile
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        // MNKRow = 0-1
        // LaneLocal = 0-15
        // VectorSize must be 8
        if constexpr(!ABDoTranspose)
        {
            const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_unmerge_transform(
                        make_tuple(Number<MNPerWmma>{}, Number<MNRepeat_Grid>{})),
                    make_pass_through_transform(Number<KPack>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{}));

            const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
                transform_tensor_descriptor(
                    ab_grid_desc_mntiles_ktiles_mnrepeat,
                    make_tuple(make_pass_through_transform(math::integer_divide_ceil(
                                   MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                               make_pass_through_transform(
                                   math::integer_divide_ceil(K_grid, Number<KPack>{})),
                               make_pass_through_transform(Number<MNRepeat_Grid>{}),
                               make_pass_through_transform(Number<MNPerWmma>{}),
                               make_unmerge_transform(
                                   make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
                    make_tuple(
                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                    make_tuple(Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4, 5>{}));

            // Freeze VectorSize to first element of the loading chunk (for convenience)
            // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
            return transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(MN_grid, Number<MNPerWmma * MNRepeat_Grid>{})),
                    make_pass_through_transform(math::integer_divide_ceil(K_grid, Number<KPack>{})),
                    make_pass_through_transform(Number<MNRepeat_Grid>{}),
                    make_pass_through_transform(Number<MNPerWmma>{}),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<4>{},
                           Sequence<3>{},
                           Sequence<5>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<>{}));
        }
    }

    __device__ static constexpr auto GetBlockDescriptor()
    {
        // LDS memory layouts:
        // lanes within tiles stored contiguously in chunks of 8 elements
        // tiles are then stored first in K dimension
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        const auto a_grid_desc_mraw_kraw = [&]() {
            return make_naive_tensor_descriptor(
                make_tuple(Number<MNWaves_Grid>{},
                           Number<KRepeat_Grid * KWaves_Grid>{},
                           Number<MNRepeat_Grid>{},
                           Number<MNKRow>{},
                           Number<MNPerWmma>{},
                           Number<ABK1Value>{}),
                make_tuple(Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid>{},
                           Number<KPack * MNPerWmma>{},
                           Number<KPack * MNPerWmma * KWaves_Grid * KRepeat_Grid * MNWaves_Grid>{},
                           Number<ABK1Value * MNPerWmma>{},
                           Number<ABK1Value>{},
                           I1));
        }();

        // Freeze VectorSize to first element of the chunk (for convenience)
        return transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            make_tuple(make_pass_through_transform(Number<MNWaves_Grid>{}),
                       make_pass_through_transform(Number<KRepeat_Grid * KWaves_Grid>{}),
                       make_pass_through_transform(Number<MNRepeat_Grid>{}),
                       make_pass_through_transform(Number<MNKRow>{}),
                       make_pass_through_transform(Number<MNPerWmma>{}),
                       make_freeze_transform(I0)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<>{}));
    }

    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id,
                                            const index_t)
    {
        // Note: GlobalBufferNum is currently not used but it will be needed
        // once we add other pipelines. It is currently needed only for
        // consistency with the thread tiles approach
        static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
        constexpr index_t NumABTensor = ABsDataType::Size();
        static_assert(NumABTensor == 1, "multiAB currently not supported");

        using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;

        const auto wave_idx = GetWaveIdx();
        index_t wave_idK = wave_idx[I1];
        index_t wave_idMN = wave_idx[I0];

        const auto grid_lane_id = Base::template GetGridLaneIdx<ABDataType>();
        index_t lane_group_grid = grid_lane_id[I0];
        index_t lane_local_id_grid = grid_lane_id[I1];

        const auto block_lane_id = GetBlockLaneIdx();
        index_t lane_group_block = block_lane_id[I0];
        index_t lane_local_id_block = block_lane_id[I1];

        constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_;
        return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
                                         BlockDescriptor,
                                         ABDataType,
                                         ABDataType,
                                         ABElementwiseOperation,
                                         Sequence<I1, KRepeat_, MNRepeat_, I1, I1>,
                                         Sequence<I1, KWaves_, I1, I1, I1>,
                                         Sequence<I0, I1, I2, I3, I4>,
                                         ABK1Value,
                                         ABDoTranspose>(
            grid_descriptor[I0],
            block_descriptor,
            make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio,
                             wave_idK * KRepeat_Grid,
                             (wave_idMN % MNRepeatRatio) * MNRepeat_,
                             lane_group_grid,
                             lane_local_id_grid),
            make_multi_index(wave_idMN / MNRepeatRatio,
                             wave_idK * KRepeat_,
                             (wave_idMN % MNRepeatRatio) * MNRepeat_,
                             lane_group_block,
                             lane_local_id_block),
            ab_element_op);
    }

    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow)
        return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0);
    }
};

} // namespace ck
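For orientation, a worked example of the interleaved LDS strides in GetBlockDescriptor above (hypothetical tile configuration, not taken from any instance in this commit):

    constexpr int KPack = 16, MNPerWmma = 16, KWaves_Grid = 2, KRepeat_Grid = 2;
    constexpr int wave_tile_elems = KPack * MNPerWmma;                            // 256
    constexpr int mn_wave_stride  = wave_tile_elems * KWaves_Grid * KRepeat_Grid; // 1024
    static_assert(mn_wave_stride == 1024, "stride between consecutive MN waves in LDS");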
@@ -177,7 +177,8 @@ template <typename ALayout,
          bool PermuteA,
          bool PermuteB,
          bool IsBPreShuffled = false,
-          bool ForceThreadTileTransfer = false>
+          bool ForceThreadTileTransfer = false,
+          bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3
    : GridwiseGemm_wmma_cshuffle_v3_base<
          ALayout,
@@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
          PermuteA,
          PermuteB,
          IsBPreShuffled,
-          ForceThreadTileTransfer>
+          ForceThreadTileTransfer,
+          IsFusedKernel>
{
    using Base = GridwiseGemm_wmma_cshuffle_v3_base<
        ALayout,
@@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
        PermuteA,
        PermuteB,
        IsBPreShuffled,
-        ForceThreadTileTransfer>;
+        ForceThreadTileTransfer,
+        IsFusedKernel>;

    using Base::I0;
    using Base::I1;
@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
@@ -24,6 +25,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
@@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
                           std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
-    constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
-        typename GridwiseGemm::EpilogueCShuffle>();
+    using EpilogueType =
+        typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
+                                      GridwiseGemm::UseDirectStore,
+                                  typename GridwiseGemm::EpilogueDirectStore,
+                                  typename GridwiseGemm::EpilogueCShuffle>::type;
+
+    constexpr index_t LDS_size =
+        GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
    __shared__ char p_shared[LDS_size];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

-    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
+    auto epilogue_args = EpilogueType{};

    GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
        p_shared, splitk_batch_offset, karg, epilogue_args);
@@ -167,7 +175,8 @@ template <typename ALayout,
          bool PermuteA,
          bool PermuteB,
          bool IsBPreShuffled = false,
-          bool ForceThreadTileTransfer = false> // only needed for convolution (limitation)
+          bool ForceThreadTileTransfer = false, // only needed for convolution (limitation)
+          bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3_base
{

@@ -182,6 +191,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base

    static constexpr index_t NumATensor = AsDataType::Size();
    static constexpr index_t NumBTensor = BsDataType::Size();
+    static constexpr index_t NumDTensor = DsDataType::Size();

    using LDSTypeA =
        typename std::conditional<(NumATensor > 1),
@@ -232,30 +242,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        return 1;
    }();

+    static constexpr index_t WaveSize =
+        WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
+            .wave_size;
+
    // Limitations of the current implementation:
    // - no multiAB
-    // - GemmSpecialization Default
    // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation)
    //   AK1Value == 8 is not really a limitation but a requirement for the method so
    //   it will stay
+    // - GemmSpecialization Default with transpose
#ifdef __gfx12__
    static constexpr bool IsAWaveTransferApplicable =
        !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
-        GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
+        ((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
+          !is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
+         is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
        BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;

    static constexpr bool IsBWaveTransferApplicable =
        !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
-        GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
+        ((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
+          !is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
+         is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
        BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;

+    static constexpr bool IsWaveTileInterleavedFitting =
+        (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
+
+    // We need to investigate if it makes sense to remove cshuffle for smaller types
+    // Currently we use direct store for NRepeat equal to 4 or 8. For 16 bit type we use at
+    // least buffer store 64 bit for 16 contiguous threads -> 128 bytes in total (full cache line)
+    static constexpr bool UseDirectStore = is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
+                                           sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 &&
+                                           NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) &&
+                                           !IsFusedKernel && IsWaveTileInterleavedFitting;
#else
    static constexpr bool IsAWaveTransferApplicable = false;
    static constexpr bool IsBWaveTransferApplicable = false;
+    static constexpr bool UseDirectStore = false;
#endif

-    static constexpr index_t WaveSize =
-        WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
-            .wave_size;
    static constexpr bool UseBlockPaddingA =
        ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
    using ATransfer = typename std::conditional<
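A worked check of the IsWaveTileInterleavedFitting condition above, with a hypothetical tile configuration (NPerBlock = 128, NPerWmma = 16, NRepeat = 4, KPerBlock = 64, KPack = 16, BlockSize = 256, WaveSize = 32):

    constexpr int n_wave_tiles    = 128 / 16 / 4; // 2 wave tiles in N per repeat group
    constexpr int k_tiles         = 64 / 16;      // 4
    constexpr int waves_per_block = 256 / 32;     // 8
    static_assert(n_wave_tiles * k_tiles >= waves_per_block, "interleaved wave tiles fit");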
@@ -293,7 +317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base

    static constexpr bool UseBlockPaddingB =
        BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
-
    using BTransfer = typename std::conditional<
        IsBPreShuffled,
        ABTransferThreadTilesPreShuffle<BLayout,
@@ -309,16 +332,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                                        BThreadTransferSrcResetCoordinateAfterRun>,
        typename std::conditional<
            IsBWaveTransferApplicable,
-            ABTransferWaveTiles<BLayout,
-                                tensor_layout::gemm::ColumnMajor,
-                                LDSTypeB,
-                                BlockSize,
-                                NPerBlock,
-                                KPerBlock,
-                                NPerWmma,
-                                KPack,
-                                BK1Value,
-                                WaveSize>,
+            typename std::conditional<
+                UseDirectStore,
+                ABTransferWaveTilesInterleave<BLayout,
+                                              tensor_layout::gemm::ColumnMajor,
+                                              LDSTypeB,
+                                              BlockSize,
+                                              NPerBlock,
+                                              KPerBlock,
+                                              NPerWmma,
+                                              KPack,
+                                              BK1Value,
+                                              WaveSize,
+                                              NPerBlock / NPerWmma / NRepeat>,
+                ABTransferWaveTiles<BLayout,
+                                    tensor_layout::gemm::ColumnMajor,
+                                    LDSTypeB,
+                                    BlockSize,
+                                    NPerBlock,
+                                    KPerBlock,
+                                    NPerWmma,
+                                    KPack,
+                                    BK1Value,
+                                    WaveSize>>::type,
            ABTransferThreadTiles<BLayout,
                                  tensor_layout::gemm::ColumnMajor,
                                  LDSTypeB,
@@ -490,6 +526,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            Number<NumATensor>{});
    }

+    template <typename GridDescBase>
+    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc)
+    {
+        const auto M = base_desc.GetLength(I0);
+        const auto K = base_desc.GetLength(I1);
+
+        const auto AK0 = K / AK1Value;
+
+        constexpr bool padM = false;
+        constexpr bool padK = false;
+        return ATransfer::template MakeGridDescriptor<padM, padK>(base_desc, M, M, K, K, 0, AK0);
+    }
+
    __host__ __device__ static auto
    MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
                                   const index_t KPad,
@@ -516,6 +565,19 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            Number<NumBTensor>{});
    }

+    template <typename GridDescBase>
+    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc)
+    {
+        const auto N = base_desc.GetLength(I0);
+        const auto K = base_desc.GetLength(I1);
+
+        const auto BK0 = K / BK1Value;
+
+        constexpr bool padN = false;
+        constexpr bool padK = false;
+        return BTransfer::template MakeGridDescriptor<padN, padK>(base_desc, N, N, K, K, 0, BK0);
+    }
+
    __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
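The K split these helpers perform, with example numbers (hypothetical; K must divide evenly here since both pad flags are false):

    constexpr int K = 128, AK1Value = 8;
    constexpr int AK0 = K / AK1Value; // 16: the K axis is viewed as (AK0, AK1) = (16, 8)
    static_assert(AK0 * AK1Value == K, "no K padding on this path");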
@@ -594,8 +656,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
    }

-    static constexpr index_t NumDTensor = DsDataType::Size();
-
    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
@@ -679,6 +739,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                                                     ThisThreadBlock,
                                                     BlockwiseGemmPipe>;

+    using EpilogueDirectStore = EpilogueDirectStore<DsDataType,
+                                                    EDataType,
+                                                    AccDataType,
+                                                    MRepeat,
+                                                    NRepeat,
+                                                    CDEElementwiseOperation,
+                                                    BlockwiseGemmPipe>;
+
    using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
        DsDataType,
        EDataType,
@@ -1000,18 +1068,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
                      max_lds_align)
                : 0;

-        // LDS allocation for C shuffle in LDS
-        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
-            EpilogueType::
-                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
-
-        constexpr auto c_block_size =
-            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
-                .GetElementSpaceSize();
-
-        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
-                          b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
-                         c_block_size * sizeof(CShuffleDataType));
+        if constexpr(EpilogueType::IsLDSNeeded())
+        {
+            // LDS allocation for C shuffle in LDS
+            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
+                EpilogueType::
+                    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
+
+            constexpr auto c_block_size =
+                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
+                    .GetElementSpaceSize();
+
+            return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
+                              b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
+                             c_block_size * sizeof(CShuffleDataType));
+        }
+        else
+        {
+            return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
+                   b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize;
+        }
    }

    template <index_t numElements, typename Type>
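The IsLDSNeeded() hook is what lets the sizing above skip the C-shuffle buffer entirely; a minimal sketch of an epilogue opting out (hypothetical type, following the pattern EpilogueDirectStore uses):

    struct MyNoLdsEpilogue // hypothetical example epilogue
    {
        __device__ static constexpr bool IsLDSNeeded() { return false; }
        // no GetCShuffleBlockDescriptor_* needed: that branch is discarded at compile time
    };
    // GetSharedMemoryNumberOfByte<MyNoLdsEpilogue>() then returns only the
    // A/B tile space, shrinking the kernel's __shared__ allocation.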
@@ -1148,7 +1224,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            num_k_block_main_loop,
            num_k_block_per_scale);

-        // shuffle C and write out
+        // Epilogue:
+        // - CShuffle / direct store
+        // - Multiple Ds
+        // - Fused operations
        epilogue_args.template Run<EGlobalMemoryDataOperation>(
            c_thread_buf,
            p_ds_grid,