mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Implement padding and sanity checks for fused GEMM+GEMM (#376)
* GemmPadder and GemmGemmPadder * proper padding using GemmGemmPadder * test gemm_gemm padding * properly check size K in IsSupportedArgument() * properly check size requirement given SrcScalarPerVector in IsSupportedArgument() * comment * format
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -188,6 +189,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -203,92 +208,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
return transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
@@ -306,84 +237,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
const auto BK0 = K / BK1;
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
return transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
|
||||
@@ -402,47 +267,19 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
|
||||
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
// TODO: implement finer-grained padding
|
||||
if constexpr(GemmSpec == GemmSpecialization::Default)
|
||||
{
|
||||
const auto B1K0 = KRaw / B1K1;
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b1_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b1_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// pad both B1N and B1K
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
const auto b1_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b1_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
@@ -460,47 +297,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
@@ -651,13 +448,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
b1_element_op_{b1_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
|
||||
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
b1_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
block_2_ctile_map_,
|
||||
raw_lengths_m_n_k_o_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -684,6 +483,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_m_n_k_o_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -697,7 +499,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
arg.block_2_ctile_map_,
|
||||
arg.raw_lengths_m_n_k_o_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
@@ -787,11 +590,37 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
|
||||
// vector is out of bounds
|
||||
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
|
||||
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
|
||||
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
|
||||
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
arg.block_2_ctile_map_,
|
||||
arg.raw_lengths_m_n_k_o_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -903,7 +732,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
<< MPerBlock << ", "
|
||||
<< Gemm1NPerBlock << ", "
|
||||
<< Gemm1KPerBlock << ", "
|
||||
<< B1K1 << ">";
|
||||
<< B1K1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -9,6 +9,7 @@ namespace device {
|
||||
|
||||
enum struct GemmSpecialization
|
||||
{
|
||||
// Gemm
|
||||
Default,
|
||||
MPadding,
|
||||
NPadding,
|
||||
@@ -17,6 +18,15 @@ enum struct GemmSpecialization
|
||||
MKPadding,
|
||||
NKPadding,
|
||||
MNKPadding,
|
||||
// Gemm + Gemm
|
||||
OPadding,
|
||||
MOPadding,
|
||||
NOPadding,
|
||||
KOPadding,
|
||||
MNOPadding,
|
||||
MKOPadding,
|
||||
NKOPadding,
|
||||
MNKOPadding,
|
||||
};
|
||||
|
||||
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
@@ -31,6 +41,14 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,166 +12,176 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// For padding tensors without batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_MRaw_NRaw::GetNumOfVisibleDimension() == 2, bool> = false>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
{
|
||||
const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{});
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
|
||||
return transform_tensor_descriptor(tensor_desc_mraw_nraw,
|
||||
make_tuple(MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// For padding tensors with batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_GRaw_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_GRaw_MRaw_NRaw::GetNumOfVisibleDimension() == 3, bool> = false>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
{
|
||||
const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{});
|
||||
const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{});
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
tensor_desc_graw_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
|
||||
// M/N/K/OPerTileType could be index_t or Number<>
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType,
|
||||
typename OPerTileType>
|
||||
struct GemmGemmPadder
|
||||
{
|
||||
// TODO: hard to scale; use mask instead
|
||||
static constexpr bool PadM =
|
||||
GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
|
||||
GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadN =
|
||||
GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadK =
|
||||
GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadO =
|
||||
GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding ||
|
||||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding ||
|
||||
GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
|
||||
// A[M, K]
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
// B[K, N]
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
// B1[Gemm1N, Gemm1K] = B1[O, N]
|
||||
template <typename B1Desc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadO, PadN>(b1_desc_nraw_kraw, OPerTile_, NPerTile_);
|
||||
}
|
||||
|
||||
// C[M, Gemm1N] = C[M, O]
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadO>(c_desc_mraw_nraw, MPerTile_, OPerTile_);
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
NPerTileType NPerTile_;
|
||||
KPerTileType KPerTile_;
|
||||
OPerTileType OPerTile_;
|
||||
};
|
||||
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct MatrixPadder
|
||||
struct GemmPadder
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr bool PadM =
|
||||
(GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
static constexpr bool PadN =
|
||||
(GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
static constexpr bool PadK =
|
||||
(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
|
||||
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
return transform_tensor_descriptor(a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
return a_desc_mraw_kraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
|
||||
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
return transform_tensor_descriptor(b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
return b_desc_nraw_kraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
|
||||
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_desc_mraw_nraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
@@ -179,6 +189,15 @@ struct MatrixPadder
|
||||
KPerTileType KPerTile_;
|
||||
};
|
||||
|
||||
// Alias of GemmPadder; to deprecate
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user