mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
This commit is contained in:
@@ -74,7 +74,10 @@ template <typename GridwiseGemm,
|
||||
typename CDEElementwiseOp,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
InMemoryDataOperationEnum OutElementOp>
|
||||
InMemoryDataOperationEnum OutElementOp,
|
||||
bool HasMainKBlockLoopInAllGemm,
|
||||
bool NoMainKBlockLoopInAllGemm,
|
||||
bool CTranspose>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -101,16 +104,21 @@ __global__ void
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
@@ -141,11 +149,11 @@ __global__ void
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
|
||||
{
|
||||
GridwiseGemm::template Run<true, OutElementOp>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
@@ -162,22 +170,44 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<true, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// implementation we can avoid copy data to workspace before kernel launch since number of
|
||||
// groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then
|
||||
// we run this kernel in the loop.
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum = 32;
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum =
|
||||
ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0
|
||||
? 1
|
||||
: 32;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
|
||||
@@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ALayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using BLayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::GKZYXC,
|
||||
BLayout>>;
|
||||
using ELayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
static constexpr bool isATensorColMajor =
|
||||
(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) &&
|
||||
(ABlockTransferSrcVectorDim == 1) &&
|
||||
(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
|
||||
|
||||
static constexpr bool NeedTransposeKernel =
|
||||
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
|
||||
|
||||
static constexpr bool CTranspose =
|
||||
(NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
|
||||
|
||||
using ALayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using BLayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
|
||||
NeedTransposeKernel,
|
||||
tensor_layout::convolution::GKZYXC,
|
||||
BLayout>>;
|
||||
using ELayoutAfterTranspose = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
@@ -329,7 +379,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
ELayoutAfterTranspose,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
EDataType,
|
||||
1,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
@@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
DDataType,
|
||||
1, /*index_t NumGroupsToMerge = 1,*/
|
||||
index_t, /* typename IndexType = */
|
||||
CTranspose>;
|
||||
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_tuple(
|
||||
b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
}
|
||||
|
||||
// GridwiseGemm
|
||||
@@ -383,13 +446,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
|
||||
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
|
||||
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
|
||||
BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
|
||||
using GridwiseGemmCTranspose = std::conditional_t<
|
||||
CTranspose,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
|
||||
GridwiseGemm>;
|
||||
|
||||
template <typename EGridDesc_M_N>
|
||||
static auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
|
||||
{
|
||||
return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename Desc_K0_M_K1>
|
||||
@@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using Block2ETileMap =
|
||||
decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
|
||||
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
|
||||
@@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
sizeof(EDataType);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
|
||||
: a_g_n_k_wos_strides;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
|
||||
: b_g_k_c_xs_strides;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
|
||||
: e_g_n_c_wis_strides;
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
const auto a_grid_desc_ak0_m_ak1 = [&]() {
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
}
|
||||
else
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = [&]() {
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
}
|
||||
else
|
||||
{
|
||||
return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
}
|
||||
}();
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
@@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
DDataType,
|
||||
1,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
@@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
|
||||
const bool HasMainKBlockLoop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
|
||||
GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_);
|
||||
|
||||
gemm_kernel_args_[gemms_count_ /
|
||||
MaxGroupedGemmGroupsNum][gemms_count_ %
|
||||
MaxGroupedGemmGroupsNum] =
|
||||
GemmArgs{a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
GridwiseGemm::
|
||||
GridwiseGemmCTranspose::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n),
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
|
||||
gemm_set_id++)
|
||||
{
|
||||
@@ -1067,42 +1170,111 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&]() {
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp>;
|
||||
bool has_loop_in_all_gemm = true;
|
||||
bool no_loop_in_all_gemm = true;
|
||||
for(auto i = 0; i < gemms_count_for_set; i++)
|
||||
{
|
||||
has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
|
||||
no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
|
||||
}
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
constexpr bool no_main_loop = no_main_k_block_loop.value;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemmCTranspose,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
BElementwiseOp,
|
||||
AElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp,
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
ElementOp,
|
||||
has_main_loop,
|
||||
no_main_loop,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
}
|
||||
};
|
||||
|
||||
ave_time += launch_kernel();
|
||||
if(has_loop_in_all_gemm)
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(no_loop_in_all_gemm)
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
// Transpose from NGKHW to NHWGK
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
@@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// Transpose from NHWGC to NGCHW
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
@@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
// Specifialization
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -1307,15 +1481,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> || NeedTransposeKernel)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if(is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
static_assert(NeedTransposeKernel == false);
|
||||
|
||||
if constexpr(ABlockTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
if(ABlockTransferSrcVectorDim != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
@@ -1351,10 +1540,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
|
||||
{
|
||||
// vector load D matrix from global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
ds_valid = false;
|
||||
// vector load D matrix from global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1376,10 +1575,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
return false;
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// Gridwise GEMM size
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
if(!GridwiseGemmCTranspose::CheckValidity(
|
||||
arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
@@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
|
||||
@@ -30,7 +30,8 @@ template <
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
typename IndexType = index_t,
|
||||
bool CTranspose = false>
|
||||
struct TransformConvBwdDataToGemm_v1
|
||||
{
|
||||
private:
|
||||
@@ -555,6 +556,41 @@ struct TransformConvBwdDataToGemm_v1
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NGKHW>)
|
||||
{
|
||||
// assume packed
|
||||
static_assert(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0);
|
||||
|
||||
const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_ * Wo_, K_), make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
out_gemm_raw_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
// assume packed
|
||||
static_assert(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0);
|
||||
|
||||
const auto out_gemm_raw_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NStrideTensorA_, I1, KStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
out_gemm_raw_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
|
||||
@@ -608,7 +644,9 @@ struct TransformConvBwdDataToGemm_v1
|
||||
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
|
||||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<ALayout_, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout_, tensor_layout::convolution::NGKDHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
|
||||
{
|
||||
@@ -848,16 +886,16 @@ struct TransformConvBwdDataToGemm_v1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BLayout_ = BLayout,
|
||||
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
|
||||
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
|
||||
bool>::type = false>
|
||||
template <
|
||||
typename BLayout_ = BLayout,
|
||||
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
|
||||
(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC> ||
|
||||
is_same_v<BLayout_, tensor_layout::convolution::GKCYX> ||
|
||||
is_same_v<BLayout_, tensor_layout::convolution::GKCZYX>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
|
||||
{
|
||||
// assume packed
|
||||
// k_y_x_c for 2d or k_z_y_x_c for 3d
|
||||
const auto wei_grid_desc = MakeWeiGridDesc();
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
@@ -886,6 +924,12 @@ struct TransformConvBwdDataToGemm_v1
|
||||
}
|
||||
else
|
||||
{
|
||||
// assume packed
|
||||
// k_y_x_c for 2d or k_z_y_x_c for 3d
|
||||
static_assert(is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>);
|
||||
const auto wei_grid_desc = MakeWeiGridDesc();
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
|
||||
@@ -1059,6 +1103,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(CTranspose == false);
|
||||
// assume strided
|
||||
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
|
||||
const auto in_grid_desc = MakeInGridDesc();
|
||||
@@ -1314,6 +1359,48 @@ struct TransformConvBwdDataToGemm_v1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout_ = CLayout,
|
||||
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
|
||||
(is_same_v<CLayout_, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<CLayout_, tensor_layout::convolution::NGCDHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
const auto in_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, C_, Di_ * Hi_ * Wi_), make_tuple(NStrideTensorC_, CStrideTensorC_, I1));
|
||||
|
||||
static_assert(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0);
|
||||
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(C_),
|
||||
make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmNPerBlock, GemmMPerBlock),
|
||||
Sequence<DoPadGemmN, DoPadGemmM>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmmraw_gemmnraw_grid_desc,
|
||||
make_tuple(GemmMPerBlock, GemmNPerBlock),
|
||||
Sequence<DoPadGemmM, DoPadGemmN>{});
|
||||
}
|
||||
}
|
||||
// for input bias
|
||||
template <typename CLayout_ = CLayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
@@ -1326,14 +1413,26 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto in_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(C_, N_ * Ho_ * Wo_), make_tuple(I1, I0));
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
|
||||
|
||||
return in_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
static_assert(CTranspose == false);
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input
|
||||
// tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
|
||||
@@ -66,6 +66,7 @@
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
|
||||
@@ -2852,11 +2852,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -2622,11 +2622,13 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::bf8_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -10,53 +10,55 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, typename = void>
|
||||
template <typename T, index_t LaneGroupSize = 16, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
|
||||
"LaneGroupSize must be 16, 32, or 64");
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -72,15 +74,15 @@ struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t LaneGroupSize,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
|
||||
return xdllevel_dstr_encoding{};
|
||||
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -994,51 +994,34 @@ struct buffer_view<address_space_enum::lds,
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
// clang-format off
|
||||
static_assert(
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
// clang-format on
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
@@ -1090,6 +1073,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
|
||||
{
|
||||
|
||||
@@ -17,6 +17,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr int DS_READ_TR_SIZE()
|
||||
{
|
||||
return 8; // Literal constant, evaluated at compile time
|
||||
}
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
@@ -45,48 +50,60 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad16
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<4, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad8
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<2, 8>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::InputEncoding,
|
||||
typename Quad8::InputEncoding>;
|
||||
typename Quad16<LaneGroupSize>::InputEncoding,
|
||||
typename Quad8<LaneGroupSize>::InputEncoding>;
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::OutputEncoding,
|
||||
typename Quad8::OutputEncoding>;
|
||||
typename Quad16<LaneGroupSize>::OutputEncoding,
|
||||
typename Quad8<LaneGroupSize>::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
@@ -96,51 +113,79 @@ struct DefaultTranspose
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode>
|
||||
struct ValidationTraits
|
||||
template <typename InDstrEncode, bool ReverseDirection, index_t LaneGroupSize>
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs = QuadEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
|
||||
decltype(input_hs_lengthss.template get<0>())>;
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I0]), decltype(input_hs[I0])>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
|
||||
decltype(input_hs_lengthss.template get<1>())>;
|
||||
util::is_sequence_suffix_v<decltype(quad_hs[I1]), decltype(input_hs[I1])>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ps_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
|
||||
static constexpr index_t ndimp_inner =
|
||||
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
|
||||
static constexpr auto quad_ps_major0 = QuadEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_ps_minor0 = QuadEncoding::ps_to_rhss_minor_[I0];
|
||||
|
||||
static constexpr auto input_ps_major_last =
|
||||
input_ps_major[number<input_ps_major.size() - 1>{}];
|
||||
static constexpr auto input_ps_minor_last =
|
||||
input_ps_minor[number<input_ps_minor.size() - 1>{}];
|
||||
|
||||
using psys_offset = ck_tile::sequence<input_hs[I0].size() - quad_hs[I0].size(),
|
||||
input_hs[I1].size() - quad_hs[I1].size()>;
|
||||
static constexpr auto shifted_quad_ps_minor0 = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
return number<quad_ps_minor0[i] + psys_offset{}[quad_ps_major0[i] - 1]>{};
|
||||
},
|
||||
number<quad_ps_minor0.size()>{});
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
|
||||
input_hs_lengthss[number<1>{}].size() - 2) &&
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 1);
|
||||
util::is_sequence_suffix_v<decltype(quad_ps_major0), decltype(input_ps_major_last)> &&
|
||||
util::is_sequence_suffix_v<decltype(shifted_quad_ps_minor0),
|
||||
decltype(input_ps_minor_last)>;
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto input_ys_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
static constexpr auto quad_ys_major = QuadEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_ys_minor = QuadEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static_assert(quad_ys_major.size() == 1 && quad_ys_minor.size() == 1,
|
||||
"YS->RHS mapping must be single dimension");
|
||||
static_assert(quad_ys_major.back() == 2 && quad_ys_minor.back() == quad_hs[I1].size() - 1,
|
||||
"YS->RHS mapping must be the last dimension");
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_to_rhs_major.back() == 2) &&
|
||||
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
|
||||
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
|
||||
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 2);
|
||||
(input_ys_major.back() == 2) && (input_ys_minor.back() == input_hs[I1].size() - 1);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
|
||||
template <typename InDstrEncode, bool ReverseDirection = false>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr bool value =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ||
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
ValidationTraitsImpl<InDstrEncode, ReverseDirection, 64>::value ? 64
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 32>::value ? 32
|
||||
: ValidationTraitsImpl<InDstrEncode, ReverseDirection, 16>::value ? 16
|
||||
: 0;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
@@ -154,111 +199,150 @@ struct TransposeTileDistrChecker
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistribution_,
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
struct OutputTileDistributionTraits
|
||||
typename Policy = DefaultTranspose<DataType_>,
|
||||
bool ReverseDirection = false>
|
||||
struct TransposeTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
|
||||
using InDstrEncode = remove_cvref_t<TileDistributionEncoding_>;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr index_t LaneGroupSize =
|
||||
Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::LaneGroupSize;
|
||||
static_assert(Policy::template ValidationTraits<InDstrEncode, ReverseDirection>::value,
|
||||
"The input tile distribution encoding is not valid for transpose!");
|
||||
|
||||
using QuadInputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>>;
|
||||
using QuadOutputEncoding = std::conditional_t< //
|
||||
ReverseDirection,
|
||||
typename Policy::template QuadInputEncoding<LaneGroupSize>,
|
||||
typename Policy::template QuadOutputEncoding<LaneGroupSize>>;
|
||||
|
||||
static constexpr auto quad_input_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
|
||||
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto quad_input_ps_to_rhss_major0 = QuadInputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_input_ps_to_rhss_minor0 = QuadInputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_major0 = QuadOutputEncoding::ps_to_rhss_major_[I0];
|
||||
static constexpr auto quad_output_ps_to_rhss_minor0 = QuadOutputEncoding::ps_to_rhss_minor_[I0];
|
||||
static constexpr auto quad_output_ys_to_rhs_major = QuadOutputEncoding::ys_to_rhs_major_;
|
||||
static constexpr auto quad_output_ys_to_rhs_minor = QuadOutputEncoding::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr index_t dim0 = Policy::transpose_dims[0];
|
||||
static constexpr index_t dim1 = Policy::transpose_dims[1];
|
||||
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
|
||||
// for transpose load
|
||||
// append the reversed quad output hs lengths to the input hs lengthss after removing
|
||||
// the quad_input_hs_lengthss
|
||||
// then reverse the whole sequence to get the dst_out_hs_lengthss
|
||||
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
|
||||
|
||||
static constexpr auto full_out_hs_lengthss = generate_tuple(
|
||||
// remove the quad_input_hs_lengthss from the input_hs_lengthss for each dimension and reverse
|
||||
// dims and append the quad_output_hs_lengthss to the end of each dimension
|
||||
static constexpr auto outer_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
return input_hs_lengthss[i]
|
||||
.extract(typename arithmetic_sequence_gen<0,
|
||||
input_hs_lengthss[i].size() -
|
||||
quad_input_hs_lengthss[i].size(),
|
||||
1>::type{})
|
||||
.push_back(reversed_quad_output_hs_lengthss[i]);
|
||||
constexpr auto input_i = input_hs_lengthss[i];
|
||||
constexpr auto outer_len = input_i.size() - quad_input_hs_lengthss[i].size();
|
||||
return typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
static constexpr auto reversed_outer_hs_lengthss = tuple_reverse(outer_hs_lengthss);
|
||||
static constexpr auto dst_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
auto outer_i = reversed_outer_hs_lengthss[i];
|
||||
// append the reversed quad output hs lengths to the outer hs lengths
|
||||
return outer_i.push_back(quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
|
||||
// sequence
|
||||
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element (which is for
|
||||
// thread distr) of the major sequence
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
|
||||
constexpr auto reduce_size = quad_input_ps_to_rhss_major0.size();
|
||||
constexpr auto quad_out = quad_output_ps_to_rhss_major0;
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.push_back(number<2>{});
|
||||
return reduced_ps_to_rhss_major.transform(swap_one_and_two).push_back(quad_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_major[i];
|
||||
// For all other sequences (i.e. warp), keep them unchanged
|
||||
return input_ps_to_rhss_major[i].transform(swap_one_and_two);
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto minor_last_index =
|
||||
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
|
||||
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
|
||||
static constexpr auto quad_idx_offset =
|
||||
transform_tuples([](auto x) { return number<x.size()>{}; }, reversed_outer_hs_lengthss);
|
||||
|
||||
// minus 1 because RsLength is not counted
|
||||
static constexpr auto quad_output_ps_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ps_to_rhss_major0));
|
||||
static constexpr auto quad_output_ys_minor_offset = to_sequence(generate_tuple_for(
|
||||
[](auto x) { return quad_idx_offset[number<x - 1>{}]; }, quad_output_ys_to_rhs_major));
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
constexpr auto input_i = input_ps_to_rhss_minor[i];
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
|
||||
constexpr auto outer_len = input_i.size() - quad_input_ps_to_rhss_minor0.size();
|
||||
constexpr auto outer_ps =
|
||||
typename sequence_split<decltype(input_i), outer_len>::left_type{};
|
||||
|
||||
return outer_ps.push_back(quad_output_ps_minor_offset +
|
||||
quad_output_ps_to_rhss_minor0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_minor[i];
|
||||
return input_i;
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
static constexpr auto outer_input_ys_to_rhs_major = input_ys_to_rhs_major.pop_back();
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
|
||||
number<modified_ps_to_rhss_major.size()>{});
|
||||
static constexpr auto dst_ys_to_rhs_major =
|
||||
outer_input_ys_to_rhs_major.transform(swap_one_and_two).push_back(number<2>{});
|
||||
|
||||
static constexpr auto modified_input_ys_to_rhs_major =
|
||||
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
|
||||
static constexpr auto dst_ys_to_rhs_minor = input_ys_to_rhs_minor.pop_back().push_back(
|
||||
number<(quad_output_ys_minor_offset + quad_output_ys_to_rhs_minor)[I0]>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
|
||||
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
|
||||
number<modified_input_ys_to_rhs_major.size()>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor =
|
||||
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
|
||||
|
||||
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
using TransposedDstrEncode =
|
||||
tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using OutputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, false>;
|
||||
template <typename TileDistributionEncoding_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
using InputTileDistributionTraits =
|
||||
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
@@ -321,9 +405,9 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<TileDistribution_,
|
||||
typename BottomTensorView_::DataType>::OutDstrEncode;
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
|
||||
156
include/ck_tile/core/utility/debug.hpp
Normal file
156
include/ck_tile/core/utility/debug.hpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
|
||||
template <char... Xs>
|
||||
struct str_literal
|
||||
{
|
||||
static constexpr const char data[] = {Xs..., '\0'};
|
||||
static constexpr const size_t size = sizeof...(Xs);
|
||||
|
||||
template <char... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(str_literal<Ys...> /*rhs*/) const
|
||||
{
|
||||
return str_literal<Xs..., Ys...>{};
|
||||
}
|
||||
|
||||
template <index_t N, char... Ys>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto duplicate_n(const str_literal<Ys...> sep)
|
||||
{
|
||||
if constexpr(N == 0)
|
||||
return str_literal<>{};
|
||||
else if constexpr(N == 1)
|
||||
return str_literal<Xs...>{};
|
||||
else
|
||||
return duplicate_n<N - 1>(sep) + str_literal<Ys..., Xs...>{};
|
||||
}
|
||||
};
|
||||
|
||||
#define make_str_literal(lit_) \
|
||||
std::apply([](auto... indices) { return str_literal<(lit_)[decltype(indices)::value]...>{}; }, \
|
||||
makeTuple(std::make_index_sequence<constexpr_strlen(lit_)>()))
|
||||
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
constexpr size_t constexpr_strlen(const char* c)
|
||||
{
|
||||
size_t t = 0;
|
||||
while(*c++)
|
||||
++t;
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct thread_buffer;
|
||||
|
||||
// Usage example: CK_PRINTF<float>{}(tensor);
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF;
|
||||
template <typename ConvertTo, char... FMTChars, char... PREFIXChars, char... SUFFIXChars>
|
||||
struct CK_PRINTF<ConvertTo,
|
||||
str_literal<FMTChars...>,
|
||||
str_literal<PREFIXChars...>,
|
||||
str_literal<SUFFIXChars...>>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto default_format()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
return make_str_literal("%8.3f");
|
||||
else if constexpr(std::is_same_v<T, int>)
|
||||
return make_str_literal("%5d");
|
||||
else if constexpr(std::is_same_v<T, unsigned int>)
|
||||
return make_str_literal("%5u");
|
||||
else
|
||||
return make_str_literal("0x%08x");
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_prefix()
|
||||
{
|
||||
constexpr auto fmt_tid = make_str_literal("tid %03d: [%02d] ");
|
||||
if constexpr(sizeof...(PREFIXChars) == 0)
|
||||
return fmt_tid;
|
||||
else
|
||||
return fmt_tid + make_str_literal(" ") + str_literal<PREFIXChars...>{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_suffix()
|
||||
{
|
||||
constexpr auto lf = make_str_literal("\n");
|
||||
if constexpr(sizeof...(SUFFIXChars) == 0)
|
||||
return lf;
|
||||
else
|
||||
return str_literal<SUFFIXChars...>{} + lf;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename Y, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE void impl(const thread_buffer<T, N>& buf,
|
||||
std::integer_sequence<index_t, Is...>) const
|
||||
{
|
||||
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
printf(fmt_wrap_v.data, get_thread_id(), N, type_convert<Y>(buf[Is])...);
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_HOST_DEVICE void operator()(const thread_buffer<T, N>& buf) const
|
||||
{
|
||||
using ConvertTo_ = std::conditional_t<std::is_same_v<ConvertTo, void>, T, ConvertTo>;
|
||||
impl<T, N, ConvertTo_>(buf, std::make_integer_sequence<index_t, N>{});
|
||||
}
|
||||
|
||||
template <typename... TS>
|
||||
CK_TILE_HOST_DEVICE void operator()(const static_distributed_tensor<TS...>& tensor) const
|
||||
{
|
||||
return operator()(tensor.get_thread_buffer());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConvertTo = void,
|
||||
typename FMT = str_literal<>,
|
||||
typename PREFIX = str_literal<>,
|
||||
typename SUFFIX = str_literal<>>
|
||||
struct CK_PRINTF_WARP0 : public CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>
|
||||
{
|
||||
using base_t = CK_PRINTF<ConvertTo, FMT, PREFIX, SUFFIX>;
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(const T& buf) const
|
||||
{
|
||||
if(get_thread_id() < get_warp_size())
|
||||
base_t::operator()(buf);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -15,10 +15,10 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
@@ -30,14 +30,14 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
@@ -122,9 +122,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
// use larger K/V LDS buffer size will lower the occupancy
|
||||
else if constexpr(64 <= kK0 || 64 <= kK1)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
||||
@@ -44,10 +44,10 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -15,6 +16,19 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_a_load_tr = std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_load_tr = std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
#else
|
||||
constexpr bool is_a_load_tr = false;
|
||||
constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -40,14 +54,34 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -218,10 +218,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -300,14 +306,23 @@ struct BlockUniversalGemmAsBsCr
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
@@ -316,6 +331,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
@@ -323,10 +342,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -382,40 +407,73 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <index_t KIdx,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
constexpr auto a_lds_load_distr = [&]() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeABlockDistributionEncode()),
|
||||
ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto b_lds_load_distr = [&]() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeBBlockDistributionEncode()),
|
||||
BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
|
||||
constexpr auto a_offset =
|
||||
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
constexpr auto b_offset =
|
||||
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
a_lds_load_tile_distr);
|
||||
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
b_lds_load_tile_distr);
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_lds_gemm_window);
|
||||
@@ -424,6 +482,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_lds_gemm_window);
|
||||
@@ -431,10 +493,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -442,7 +510,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC
|
||||
// cluster, but except the first, as we can shorten non-MAC cluster a bit
|
||||
@@ -543,29 +611,45 @@ struct BlockUniversalGemmAsBsCr
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,13 @@ struct GemmPipelineAgBgCrImplBase
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
#else
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
@@ -50,11 +57,15 @@ struct GemmPipelineAgBgCrImplBase
|
||||
store_tile(lds_tile_window, block_tile_tmp);
|
||||
}
|
||||
|
||||
template <typename DstBlockTile, typename SrcTileWindow>
|
||||
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
|
||||
const SrcTileWindow& lds_tile_window) const
|
||||
const SrcTileWindow& lds_tile_window,
|
||||
bool_constant<LoadTranspose> = {}) const
|
||||
{
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
if constexpr(LoadTranspose)
|
||||
dst_block_tile = load_tile_transpose(lds_tile_window);
|
||||
else
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
|
||||
@@ -96,14 +107,25 @@ struct GemmPipelineAgBgCrImplBase
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_lds_load_tile_distr = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename ALdsLoadTileDistr::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsLoadTileDistr{};
|
||||
}();
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsLoadTileDistr{});
|
||||
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(a_copy_dram_window),
|
||||
std::move(a_copy_lds_window),
|
||||
@@ -130,14 +152,25 @@ struct GemmPipelineAgBgCrImplBase
|
||||
// TODO: Do we really need those two tile windows???
|
||||
// They're exactly same...
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename BLdsLoadTileDistr::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsLoadTileDistr{};
|
||||
}();
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsLoadTileDistr{});
|
||||
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(b_copy_dram_window),
|
||||
std::move(b_copy_lds_window),
|
||||
|
||||
@@ -153,6 +153,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
@@ -467,7 +470,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -478,7 +481,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -494,7 +497,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -506,7 +510,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -517,7 +521,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -536,7 +540,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -578,7 +583,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -141,6 +141,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -305,17 +308,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
auto a_copy_lds_window0 = make_tile_window(
|
||||
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v())
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
|
||||
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_copy_lds_window1 = make_tile_window(
|
||||
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto b_copy_lds_window0 = make_tile_window(
|
||||
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto b_copy_lds_window1 = make_tile_window(
|
||||
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v())
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
|
||||
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
@@ -325,7 +334,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -336,7 +345,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -354,51 +363,53 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
ALdsTile a_block_tile0;
|
||||
ALdsTile a_block_tile1;
|
||||
|
||||
BLdsTile b_block_tile0;
|
||||
BLdsTile b_block_tile1;
|
||||
|
||||
constexpr auto a_lds_input_tile_distr = [&]() {
|
||||
if constexpr(is_a_load_tr_v())
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
decltype(BlockGemm::MakeABlockDistributionEncode()),
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [&]() {
|
||||
if constexpr(is_b_load_tr_v())
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
decltype(BlockGemm::MakeBBlockDistributionEncode()),
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
auto a_lds_ld_window0 =
|
||||
make_tile_window(a_lds_block0,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsTileDistr);
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto a_lds_ld_window1 =
|
||||
make_tile_window(a_lds_block1,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsTileDistr);
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto b_lds_ld_window0 =
|
||||
make_tile_window(b_lds_block0,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsTileDistr);
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
auto b_lds_ld_window1 =
|
||||
make_tile_window(b_lds_block1,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsTileDistr);
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
|
||||
static_assert(
|
||||
!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>)&&!(
|
||||
is_tile_window_linear_v<
|
||||
decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
|
||||
"LDS windows must not be linear");
|
||||
static_assert(!is_tile_window_linear_v<decltype(a_lds_ld_window0)> &&
|
||||
!is_tile_window_linear_v<decltype(a_lds_ld_window1)> &&
|
||||
!is_tile_window_linear_v<decltype(b_lds_ld_window0)> &&
|
||||
!is_tile_window_linear_v<decltype(b_lds_ld_window1)>,
|
||||
"LDS windows must not be linear");
|
||||
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -409,7 +420,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -433,10 +444,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// ping
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -448,7 +459,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -473,10 +484,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// pong
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -488,7 +499,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -521,9 +532,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// 3
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
if constexpr(is_a_col_major)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -534,7 +545,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -550,8 +561,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// 2
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
// 1
|
||||
@@ -565,8 +576,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// 2
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
|
||||
@@ -21,15 +21,27 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr bool single_load_tr_length =
|
||||
(DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) ==
|
||||
(WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size());
|
||||
constexpr auto wg_attr_num_access =
|
||||
((is_a_load_tr<Problem> || is_b_load_tr<Problem>)&&!single_load_tr_length)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -196,6 +196,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -272,10 +275,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto& b_lds_block = ab_lds_blocks.at(I1{});
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
@@ -332,7 +335,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -343,7 +346,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -373,12 +376,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -394,7 +398,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -427,12 +431,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -445,7 +450,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -461,14 +466,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
@@ -558,10 +565,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto& b_lds_block = ab_lds_blocks.at(I1{});
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
@@ -617,7 +624,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -628,7 +635,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -658,10 +665,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -677,7 +688,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -709,10 +720,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto HotLoopTail = [&](auto tail_num) {
|
||||
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -725,7 +740,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -741,13 +756,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
{
|
||||
|
||||
@@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
@@ -49,6 +49,9 @@ struct GemmPipelineProblemBase
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
// In the base situation, the Preshuffle setting should be false.
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -12,6 +12,20 @@ namespace ck_tile {
|
||||
template <typename Derived>
|
||||
struct UniversalGemmBasePolicy
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
|
||||
#else
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
@@ -22,51 +36,65 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<1>{}),
|
||||
number<MPerBlock>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -78,14 +106,24 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
#if 1
|
||||
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<1>{}),
|
||||
number<NPerBlock>{},
|
||||
number<1>{});
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
// else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
@@ -584,8 +622,19 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
@@ -594,7 +643,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity>;
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -84,7 +84,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
|
||||
@@ -21,22 +21,29 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
@@ -56,25 +63,33 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -123,22 +138,29 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
@@ -159,25 +181,33 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -247,17 +277,25 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,10 +8,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
// Number of groups of consecutive elements to fill in a ABKLane
|
||||
enum class WGAttrNumAccessEnum
|
||||
{
|
||||
Single = 1,
|
||||
Double = 2,
|
||||
Quad = 4,
|
||||
Invalid = -1
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -31,21 +43,35 @@ struct WarpGemmAtrributeMfma
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -73,12 +99,16 @@ struct WarpGemmAtrributeMfma
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
@@ -104,17 +134,37 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
@@ -127,6 +177,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -143,17 +195,38 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<AttrNumAccessV,
|
||||
Impl::kABKLane,
|
||||
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -166,6 +239,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
static_assert(AttrNumAccessV == 1,
|
||||
"Multiple access is not supported when using multi-block");
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
@@ -289,10 +364,13 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
|
||||
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
@@ -312,21 +390,35 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
template <index_t kMNLane>
|
||||
static constexpr auto get_warp_dstr_encoding()
|
||||
{
|
||||
if constexpr(AttrNumAccessV == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerThread % AttrNumAccessV == 0,
|
||||
"kKPerThread must be divisible by NumAccess");
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kMNLane>,
|
||||
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
|
||||
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
@@ -450,10 +542,13 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
@@ -478,80 +573,14 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_bwarp_dstr_encoding();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_awarp_dstr_encoding();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
|
||||
@@ -16,8 +16,9 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -25,12 +26,20 @@ struct WarpGemmMfmaDispatcher;
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
|
||||
@@ -48,12 +57,20 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
|
||||
@@ -84,10 +101,18 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
@@ -106,8 +131,9 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
@@ -116,6 +142,7 @@ using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity>::Type;
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user