mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
fix merge from upstream
This commit is contained in:
@@ -62,10 +62,10 @@ template <index_t NumDimG,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
@@ -73,13 +73,14 @@ template <index_t NumDimG,
|
||||
TensorSpecialization ASpec,
|
||||
TensorSpecialization BSpec,
|
||||
TensorSpecialization DESpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
@@ -100,7 +101,6 @@ template <index_t NumDimG,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
@@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
static constexpr auto BEnableLds_manu = false;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
|
||||
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
|
||||
{
|
||||
assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
|
||||
a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
|
||||
@@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
|
||||
|
||||
if constexpr(ASpec == TensorSpecialization::Packed)
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(ASpec == TensorSpecialization::Packed)
|
||||
{
|
||||
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
|
||||
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
|
||||
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(M, K),
|
||||
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
|
||||
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else
|
||||
{
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
a_grid_desc_ms_ks,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
|
||||
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
|
||||
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(M, K),
|
||||
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
|
||||
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
a_grid_desc_ms_ks,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
|
||||
static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
|
||||
{
|
||||
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
|
||||
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
|
||||
@@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
// lengths for N0, N1, ...
|
||||
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
|
||||
|
||||
if constexpr(BSpec == TensorSpecialization::Packed)
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(BSpec == TensorSpecialization::Packed)
|
||||
{
|
||||
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
|
||||
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
|
||||
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(N, K),
|
||||
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
|
||||
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else
|
||||
{
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
const auto b_grid_desc_ns_ks =
|
||||
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
|
||||
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
|
||||
b_grid_desc_ns_ks,
|
||||
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
|
||||
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
|
||||
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(N, K),
|
||||
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
|
||||
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
const auto b_grid_desc_ns_ks =
|
||||
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
|
||||
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
|
||||
b_grid_desc_ns_ks,
|
||||
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
@@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
EGridDesc_G_M_N e_grid_desc_g_m_n_;
|
||||
};
|
||||
|
||||
// A desc for source in blockwise copy
|
||||
template <typename AGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
{
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
|
||||
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
|
||||
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
|
||||
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
DsDataType,
|
||||
EDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
@@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
AEnableLds,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BEnableLds,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
@@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_m_k_{},
|
||||
b_grid_desc_n_k_{},
|
||||
a_grid_desc_{},
|
||||
b_grid_desc_{},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{},
|
||||
ds_grid_desc_g_m_n_{
|
||||
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
|
||||
e_grid_desc_g_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
block_2_ctile_map_{},
|
||||
@@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
});
|
||||
|
||||
a_grid_desc_m_k_ =
|
||||
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
|
||||
b_grid_desc_n_k_ =
|
||||
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
|
||||
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
|
||||
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
|
||||
|
||||
ds_grid_desc_m_n_ =
|
||||
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
|
||||
@@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
e_grid_desc_m_n_ =
|
||||
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
|
||||
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
@@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// Tensor Descriptors
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
AGridDesc a_grid_desc_;
|
||||
BGridDesc b_grid_desc_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
|
||||
EGridDesc_G_M_N e_grid_desc_g_m_n_;
|
||||
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
// Batch Offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
// for checking vector load/store
|
||||
// index_t MRaw_;
|
||||
// index_t NRaw_;
|
||||
// index_t KRaw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
const auto K = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
|
||||
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
@@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
BDataType,
|
||||
typename GridwiseOp::DsGridPointer,
|
||||
EDataType,
|
||||
DeviceOp::AGridDesc_K0_M_K1,
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::AGridDesc,
|
||||
DeviceOp::BGridDesc,
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
AElementwiseOperation,
|
||||
@@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
G,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
@@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
printf("DeviceOp: Arch check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
printf("GridwiseOp: Validity check failure\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
if constexpr(ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.a_mz_stride_ == 1 &&
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
|
||||
arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-m check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
if constexpr(BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.b_nz_stride_ == 1 &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
|
||||
arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
printf("DeviceOp: Vector Access B-n check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.b_kz_stride_ == 1 &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
|
||||
arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
printf("DeviceOp: Vector Access B-k check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
{
|
||||
printf("DeviceOp: Vector Access D-n check failure\n");
|
||||
valid_d_access = false;
|
||||
}
|
||||
});
|
||||
@@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
0) ||
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
{
|
||||
printf("DeviceOp: Vector Access E-n check failure\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWMMA << ", "
|
||||
<< NPerWMMA << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " NumPrefetch: "
|
||||
<< " AEnableLds: "
|
||||
<< AEnableLds << ", "
|
||||
<< "BEnableLds: "
|
||||
<< BEnableLds << ", "
|
||||
<< "NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -602,7 +602,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
||||
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
|
||||
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
|
||||
std::is_same<ADataType, double>::value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -294,7 +294,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
ck::get_device_name() == "gfx942"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim,
|
||||
index_t BlockSize,
|
||||
index_t M0PerBlock,
|
||||
index_t M1PerBlock,
|
||||
index_t M0PerThread,
|
||||
index_t M1PerThread,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwiseImpl
|
||||
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
|
||||
{
|
||||
index_t most_continous_dim = NumDim - 1;
|
||||
index_t most_continous_dim_stride = strides[most_continous_dim];
|
||||
for(index_t dim = 0; dim < NumDim; dim++)
|
||||
{
|
||||
if(strides[dim] < most_continous_dim_stride)
|
||||
{
|
||||
most_continous_dim_stride = strides[dim];
|
||||
most_continous_dim = dim;
|
||||
}
|
||||
}
|
||||
return most_continous_dim;
|
||||
}
|
||||
|
||||
template <typename InOutDescriptor>
|
||||
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
|
||||
{
|
||||
const auto M0 = desc.GetLength(I0);
|
||||
const auto M1 = desc.GetLength(I1);
|
||||
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
|
||||
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
|
||||
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return padded_desc;
|
||||
}
|
||||
|
||||
static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
|
||||
const index_t M0_dim,
|
||||
const index_t M1_dim)
|
||||
{
|
||||
// Generate batch dims, they will be merged to M0
|
||||
// Add one more dim than needed in case that M0 is equal to M1
|
||||
// If M0 is equal to M1, then will be one more batch dim
|
||||
std::array<index_t, NumDim - 1> batch_dims;
|
||||
index_t batch_dim = 0;
|
||||
for(index_t i = 0; i < NumDim; i++)
|
||||
{
|
||||
if(i != M0_dim && i != M1_dim)
|
||||
{
|
||||
batch_dims[batch_dim] = lengths[i];
|
||||
batch_dim++;
|
||||
}
|
||||
}
|
||||
// Add dummy dim if M0_dim is not equal to M1_dim
|
||||
if(M0_dim != M1_dim && NumDim >= 2)
|
||||
batch_dims[NumDim - 2] = 1;
|
||||
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
|
||||
}
|
||||
|
||||
static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& in_strides,
|
||||
const std::array<index_t, NumDim>& out_strides,
|
||||
const std::array<index_t, NumDim>& desc_strides)
|
||||
{
|
||||
const auto M0_dim = GetLowestStrideDim(out_strides);
|
||||
const auto M1_dim = GetLowestStrideDim(in_strides);
|
||||
|
||||
// If M0_dim is equal to M1_dim, then make M0_dim dummy
|
||||
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
|
||||
const auto M1 = lengths[M1_dim];
|
||||
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
|
||||
const auto M1_stride = desc_strides[M1_dim];
|
||||
|
||||
const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
|
||||
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
|
||||
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
|
||||
// Merged batch dims with M0
|
||||
const auto transforms =
|
||||
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
|
||||
make_pass_through_transform(M1));
|
||||
using BatchElemsSequence =
|
||||
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
|
||||
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
|
||||
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
|
||||
// desc: (merged_dims + M0, M1)
|
||||
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
return PadInputOutputDescriptor(merged_desc);
|
||||
}
|
||||
|
||||
template <index_t NumTensors>
|
||||
static auto GenerateInOutGridDescTuple()
|
||||
{
|
||||
std::array<index_t, NumDim> ones;
|
||||
for(index_t d = 0; d < NumDim; d++)
|
||||
{
|
||||
ones[d] = 1;
|
||||
}
|
||||
|
||||
return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
|
||||
Number<NumTensors>{});
|
||||
};
|
||||
|
||||
using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
|
||||
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());
|
||||
|
||||
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;
|
||||
|
||||
using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
M0PerBlock,
|
||||
M1PerBlock,
|
||||
M0PerThread,
|
||||
M1PerThread,
|
||||
ThreadClusterArrangeOrder,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq,
|
||||
false>;
|
||||
|
||||
using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
M0PerBlock,
|
||||
M1PerBlock,
|
||||
M0PerThread,
|
||||
M1PerThread,
|
||||
ThreadClusterArrangeOrder,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq,
|
||||
true>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op)
|
||||
{
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
auto in_grid_desc_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
// Use Strides from first tensor to assert that M0 dim and
|
||||
// M1 dim are the same for each tensor.
|
||||
return MakeDescriptor(arg.lengths_,
|
||||
arg.inStridesArray_[I0],
|
||||
arg.outStridesArray_[I0],
|
||||
arg.inStridesArray_[src_i]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_desc_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return MakeDescriptor(arg.lengths_,
|
||||
arg.inStridesArray_[I0],
|
||||
arg.outStridesArray_[I0],
|
||||
arg.outStridesArray_[dst_i]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
|
||||
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
|
||||
|
||||
const auto block_2_tile_map = Block2TileMap(M0, M1);
|
||||
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
|
||||
|
||||
const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
|
||||
GetLowestStrideDim(arg.outStridesArray_[I0]);
|
||||
|
||||
const auto kernel = in_out_same_vector_dim
|
||||
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
|
||||
InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation>
|
||||
: kernel_elementwise<GridwiseElementwiseOp,
|
||||
InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_tuple,
|
||||
out_grid_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
block_2_tile_map,
|
||||
arg.elementwise_op_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
|
||||
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector,
|
||||
index_t M_dim) {
|
||||
if(scalarPerVector == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
|
||||
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
|
||||
is_valid &= IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
|
||||
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
|
||||
is_valid &= IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
|
||||
});
|
||||
|
||||
return is_valid;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
{
|
||||
return Argument{lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceElementwiseImpl<";
|
||||
str << NumDim << ", ";
|
||||
str << BlockSize << ", ";
|
||||
str << M0PerBlock << ", ";
|
||||
str << M1PerBlock << ", ";
|
||||
str << M0PerThread << ", ";
|
||||
str << M1PerThread << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceElementwiseNormalizationImpl<";
|
||||
str << NumDim << ", ";
|
||||
str << MPerThread << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -0,0 +1,714 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
|
||||
// 2. C(M, N) = A(M, K) * DequantB(K, N)
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ScaleDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::weight_only>
|
||||
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
// LDS bypass feature not implemented for dequantization pipeline.
|
||||
static constexpr auto AEnableLds_manu = true;
|
||||
static constexpr auto BEnableLds_manu = true;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle;
|
||||
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
|
||||
{
|
||||
// assume Scale is [1, N]
|
||||
const auto scale_grid_desc_n_k = [&]() {
|
||||
const auto scale_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
|
||||
}();
|
||||
|
||||
const auto N = scale_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = scale_grid_desc_n_k.GetLength(I1);
|
||||
// When K = 1, it might be scale tensor.
|
||||
assert(K % K1 == 0 && K != 1);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
|
||||
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
|
||||
using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseFpAintBGemm_Wmma<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
ScaleGridDesc,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
AEnableLds,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BEnableLds,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
const ScaleDataType* p_scale_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_scale_grid_{p_scale_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_{},
|
||||
b_grid_desc_{},
|
||||
scale_grid_desc_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
MRaw_{M},
|
||||
NRaw_{N},
|
||||
KRaw_{K}
|
||||
{
|
||||
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
|
||||
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
|
||||
scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
|
||||
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
const ScaleDataType* p_scale_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc a_grid_desc_;
|
||||
BGridDesc b_grid_desc_;
|
||||
ScaleGridDesc scale_grid_desc_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
// for checking vector load/store
|
||||
index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
index_t KRaw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
|
||||
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
const auto kernel = kernel_fpAintB_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ScaleDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc>,
|
||||
remove_reference_t<DeviceOp::BGridDesc>,
|
||||
remove_reference_t<DeviceOp::ScaleGridDesc>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
has_main_k_block_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_scale_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.scale_grid_desc_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
printf("DeviceOp err: AccDataType");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("DeviceOp err: Arch");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of C
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const ScaleDataType* p_scale,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_scale,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_scale,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{
|
||||
{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"},
|
||||
{PipelineVersion::weight_only, "weight_only"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceFpAintBGemm_Wmma_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " AEnableLds: "
|
||||
<< AEnableLds << ", "
|
||||
<< "BEnableLds: "
|
||||
<< BEnableLds << ", "
|
||||
<< "NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -27,21 +28,22 @@ template <typename ALayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
@@ -62,7 +64,6 @@ template <typename ALayout,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
@@ -83,68 +84,139 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
static constexpr auto BEnableLds_manu = false;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ELayout_>
|
||||
@@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
|
||||
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
@@ -207,9 +279,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -222,6 +294,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
AEnableLds,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -230,6 +303,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BEnableLds,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
@@ -262,8 +336,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
a_grid_desc{},
|
||||
b_grid_desc{},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
@@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
NRaw_{N},
|
||||
KRaw_{K}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
|
||||
b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
@@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
|
||||
|
||||
if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
if(GridwiseOp::CheckValidity(a_grid_desc,
|
||||
b_grid_desc,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
@@ -318,8 +392,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// Tensor Descriptors
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
AGridDesc a_grid_desc;
|
||||
BGridDesc b_grid_desc;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -352,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b_grid_desc,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
@@ -381,91 +439,64 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
const auto K = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
|
||||
arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
|
||||
float ave_time = 0;
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
|
||||
GridwiseOp,
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename GridwiseOp::DsGridPointer,
|
||||
EDataType,
|
||||
remove_reference_t<typename DeviceOp::AGridDesc>,
|
||||
remove_reference_t<typename DeviceOp::BGridDesc>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
|
||||
has_main_k_block_loop>; // Last Option is W/O
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_grid_desc,
|
||||
arg.b_grid_desc,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
|
||||
GridwiseOp,
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename GridwiseOp::DsGridPointer,
|
||||
EDataType,
|
||||
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
|
||||
true>; // Last Option is W/O
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
|
||||
GridwiseOp,
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename GridwiseOp::DsGridPointer,
|
||||
EDataType,
|
||||
remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
return GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b_grid_desc,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
@@ -681,14 +712,18 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWMMA << ", "
|
||||
<< NPerWMMA << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " NumPrefetch: "
|
||||
<< " AEnableLds: "
|
||||
<< AEnableLds << ", "
|
||||
<< "BEnableLds: "
|
||||
<< BEnableLds << ", "
|
||||
<< "NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
|
||||
@@ -498,6 +498,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -505,87 +585,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
|
||||
GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <class ADesc, class BDesc, class DsDesc, class EDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
static constexpr auto ds_tuple()
|
||||
{
|
||||
return transform_tuples(
|
||||
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
|
||||
DsDesc{});
|
||||
}
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
|
||||
using BGridDesc_N_K =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
|
||||
using EGridDesc_M_N =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_tuple()))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
|
||||
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
|
||||
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k;
|
||||
BGridDesc_N_K b_grid_desc_n_k;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
EGridDesc_M_N e_grid_desc_m_n;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// for checking vector load/store
|
||||
index_t MRaw;
|
||||
index_t NRaw;
|
||||
index_t KRaw;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
DsDesc ds,
|
||||
EDesc e,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_)
|
||||
: a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
|
||||
b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
|
||||
ds_grid_desc_m_n{transform_tuples(
|
||||
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
|
||||
ds)},
|
||||
e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
|
||||
a_grid_desc_ak0_m_ak1{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
|
||||
b_grid_desc_bk0_n_bk1{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
transform_tuples(
|
||||
[&](auto d) constexpr {
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
|
||||
},
|
||||
ds))},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n)},
|
||||
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
MRaw{e.GetLength(I0)},
|
||||
NRaw{e.GetLength(I1)},
|
||||
KRaw{a.GetLength(I1)}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
ds_grid_desc_m_n,
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map) and
|
||||
IsSupported(MRaw, NRaw, KRaw);
|
||||
}
|
||||
|
||||
constexpr index_t GetBlockSize() const { return BlockSize; }
|
||||
|
||||
constexpr index_t GetGridSize() const
|
||||
{
|
||||
return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class DsDesc, class EDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
DsDesc ds,
|
||||
EDesc e,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
|
||||
a, b, ds, e, a_element_op, b_element_op, cde_element_op);
|
||||
}
|
||||
|
||||
template <class Desc, class DsPointer>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid)
|
||||
{
|
||||
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
assert(desc.IsValid());
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -33,13 +34,14 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
@@ -60,7 +62,6 @@ template <typename ALayout,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
@@ -76,68 +77,138 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
static constexpr auto BEnableLds_manu = false;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
@@ -159,56 +230,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
|
||||
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_Wmma<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
AEnableLds,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BEnableLds,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -230,7 +303,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
a_grid_desc_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
@@ -244,19 +317,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
NRaw_{N},
|
||||
KRaw_{K}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ =
|
||||
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ =
|
||||
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
|
||||
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -268,8 +337,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
AGridDesc a_grid_desc_;
|
||||
BGridDesc b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -292,23 +361,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
@@ -320,79 +373,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
const auto K = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
|
||||
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
const auto kernel = kernel_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
has_main_k_block_loop>;
|
||||
|
||||
float ave_time = 0;
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>; // Last Option is W/O
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -413,13 +445,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
printf("DeviceOp err: AccDataType");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("DeviceOp err: Arch");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -485,7 +520,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
@@ -581,14 +616,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWMMA << ", "
|
||||
<< NPerWMMA << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " NumPrefetch: "
|
||||
<< " AEnableLds: "
|
||||
<< AEnableLds << ", "
|
||||
<< "BEnableLds: "
|
||||
<< BEnableLds << ", "
|
||||
<< "NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
|
||||
@@ -60,7 +60,9 @@ template <typename ADataType,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename ComputeType = CDataType,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename LDSTypeA = ComputeType,
|
||||
typename LDSTypeB = ComputeType>
|
||||
|
||||
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
|
||||
using ComputeTypeA = ComputeType;
|
||||
using ComputeTypeB = ComputeType;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer,
|
||||
ComputeType>;
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
|
||||
@@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
KPerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
@@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
|
||||
@@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
using CShuffleDataType = AccDataType;
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
// InMemory Data Descriptor
|
||||
@@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
KPerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
@@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
true,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
true,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
|
||||
@@ -52,22 +52,23 @@ template <index_t NDimSpatial,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
@@ -88,7 +89,6 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
@@ -109,11 +109,31 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * K1;
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = true;
|
||||
static constexpr auto BEnableLds_manu = true;
|
||||
|
||||
static constexpr auto AEnableLds =
|
||||
AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
static constexpr auto BEnableLds =
|
||||
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
|
||||
static constexpr auto conv_to_gemm_transformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
|
||||
@@ -122,17 +142,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
|
||||
@@ -149,13 +168,44 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
|
||||
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
|
||||
static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
|
||||
@@ -164,7 +214,39 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
|
||||
return wei_gemmn_gemmk_desc;
|
||||
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
|
||||
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
@@ -197,53 +279,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using AGridDesc =
|
||||
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
|
||||
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
// A desc for source in blockwise copy
|
||||
template <typename AGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
{
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK1 = K1;
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK1 = K1;
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -252,8 +295,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
DsDataType,
|
||||
EDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
@@ -264,9 +307,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -279,6 +322,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
AEnableLds,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -287,6 +331,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BEnableLds,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
@@ -327,23 +372,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)},
|
||||
b_grid_desc_{
|
||||
DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
|
||||
@@ -395,8 +438,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
@@ -411,14 +454,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
AGridDesc a_grid_desc_;
|
||||
BGridDesc b_grid_desc_;
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -465,8 +506,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
const auto K = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
|
||||
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
@@ -480,8 +530,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc,
|
||||
DeviceOp::BGridDesc,
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
|
||||
@@ -501,8 +551,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_c_wis_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
@@ -670,8 +720,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
return GridwiseOp::CheckValidity(arg.a_grid_desc_,
|
||||
arg.b_grid_desc_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
@@ -790,9 +840,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
<< KPerBlock << ", "
|
||||
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " AEnableLds: "
|
||||
<< AEnableLds << ", "
|
||||
<< "BEnableLds: "
|
||||
<< BEnableLds << ", "
|
||||
<< "ABlockTransferSrcScalarPerVector: "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector
|
||||
<< ">";
|
||||
<< "BBlockTransferSrcScalarPerVector: "
|
||||
<< BBlockTransferSrcScalarPerVector;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -650,22 +650,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
|
||||
// in IsSupportedArgument function
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
@@ -678,6 +665,39 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time = launch_kernel(
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
// For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
|
||||
// instruction that supports bf16 and we cannot use splitk because of that
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
supported = supported & (arg.k_batch_ == 1);
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user