mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Implement padding and sanity checks for fused GEMM+GEMM (#376)
* GemmPadder and GemmGemmPadder
* proper padding using GemmGemmPadder
* test gemm_gemm padding
* properly check size K in IsSupportedArgument()
* properly check size requirement given SrcScalarPerVector in IsSupportedArgument()
* comment
* format
[ROCm/composable_kernel commit: f4047c9418]
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -188,6 +189,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -203,92 +208,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
return transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
@@ -306,84 +237,18 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
const auto BK0 = K / BK1;
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
return transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
|
||||
@@ -402,47 +267,19 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
|
||||
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
// TODO: implement finer-grained padding
|
||||
if constexpr(GemmSpec == GemmSpecialization::Default)
|
||||
{
|
||||
const auto B1K0 = KRaw / B1K1;
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b1_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b1_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// pad both B1N and B1K
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
const auto b1_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b1_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
@@ -460,47 +297,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
@@ -651,13 +448,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
b1_element_op_{b1_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
|
||||
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
b1_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
block_2_ctile_map_,
|
||||
raw_lengths_m_n_k_o_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -684,6 +483,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_m_n_k_o_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -697,7 +499,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
arg.block_2_ctile_map_,
|
||||
arg.raw_lengths_m_n_k_o_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
@@ -787,11 +590,37 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
|
||||
// vector is out of bounds
|
||||
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
|
||||
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
|
||||
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
|
||||
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
arg.block_2_ctile_map_,
|
||||
arg.raw_lengths_m_n_k_o_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -903,7 +732,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
<< MPerBlock << ", "
|
||||
<< Gemm1NPerBlock << ", "
|
||||
<< Gemm1KPerBlock << ", "
|
||||
<< B1K1 << ">";
|
||||
<< B1K1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -9,6 +9,7 @@ namespace device {
|
||||
|
||||
enum struct GemmSpecialization
|
||||
{
|
||||
// Gemm
|
||||
Default,
|
||||
MPadding,
|
||||
NPadding,
|
||||
@@ -17,6 +18,15 @@ enum struct GemmSpecialization
|
||||
MKPadding,
|
||||
NKPadding,
|
||||
MNKPadding,
|
||||
// Gemm + Gemm
|
||||
OPadding,
|
||||
MOPadding,
|
||||
NOPadding,
|
||||
KOPadding,
|
||||
MNOPadding,
|
||||
MKOPadding,
|
||||
NKOPadding,
|
||||
MNKOPadding,
|
||||
};
|
||||
|
||||
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
@@ -31,6 +41,14 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,166 +12,176 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// For padding tensors without batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_MRaw_NRaw::GetNumOfVisibleDimension() == 2, bool> = false>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
{
|
||||
const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{});
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
|
||||
return transform_tensor_descriptor(tensor_desc_mraw_nraw,
|
||||
make_tuple(MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// For padding tensors with batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_GRaw_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_GRaw_MRaw_NRaw::GetNumOfVisibleDimension() == 3, bool> = false>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
{
|
||||
const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{});
|
||||
const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{});
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
tensor_desc_graw_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
|
||||
// M/N/K/OPerTileType could be index_t or Number<>
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType,
|
||||
typename OPerTileType>
|
||||
struct GemmGemmPadder
|
||||
{
|
||||
// TODO: hard to scale; use mask instead
|
||||
static constexpr bool PadM =
|
||||
GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
|
||||
GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadN =
|
||||
GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadK =
|
||||
GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
static constexpr bool PadO =
|
||||
GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding ||
|
||||
GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding ||
|
||||
GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
|
||||
GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
|
||||
|
||||
// A[M, K]
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
// B[K, N]
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
// B1[Gemm1N, Gemm1K] = B1[O, N]
|
||||
template <typename B1Desc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadO, PadN>(b1_desc_nraw_kraw, OPerTile_, NPerTile_);
|
||||
}
|
||||
|
||||
// C[M, Gemm1N] = C[M, O]
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadO>(c_desc_mraw_nraw, MPerTile_, OPerTile_);
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
NPerTileType NPerTile_;
|
||||
KPerTileType KPerTile_;
|
||||
OPerTileType OPerTile_;
|
||||
};
|
||||
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct MatrixPadder
|
||||
struct GemmPadder
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr bool PadM =
|
||||
(GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
static constexpr bool PadN =
|
||||
(GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
static constexpr bool PadK =
|
||||
(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
|
||||
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
|
||||
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
return transform_tensor_descriptor(a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
return a_desc_mraw_kraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
|
||||
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
return transform_tensor_descriptor(b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
return b_desc_nraw_kraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
}
|
||||
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
|
||||
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_desc_mraw_nraw;
|
||||
}
|
||||
return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
@@ -179,6 +189,15 @@ struct MatrixPadder
|
||||
KPerTileType KPerTile_;
|
||||
};
|
||||
|
||||
// Alias of GemmPadder; to deprecate
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const std::vector<index_t>& lengths_m_n_k_o)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
// K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
|
||||
const auto KRaw = lengths_m_n_k_o[2];
|
||||
if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
|
||||
Gemm1N % Gemm1NPerBlock == 0))
|
||||
{
|
||||
|
||||
@@ -114,4 +114,18 @@ struct conditional<false, X, Y>
|
||||
template <bool predicate, class X, class Y>
|
||||
using conditional_t = typename conditional<predicate, X, Y>::type;
|
||||
|
||||
// z = predicate ? x : y
|
||||
template <bool predicate, typename X, typename Y>
|
||||
constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
{
|
||||
if constexpr(predicate)
|
||||
{
|
||||
return std::forward<X>(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::forward<Y>(y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
|
||||
// c[g, m, n] = a[g, m, k] * b[g, n, k]
|
||||
using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
|
||||
@@ -37,7 +38,9 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
// Padded fallback kernel
|
||||
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -195,6 +195,12 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// early fail when no instances are found
|
||||
if(op_ptrs.size() == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
|
||||
@@ -19,6 +19,74 @@ TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); }
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// Currently expected that no kernels can support this case
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
@@ -37,3 +105,44 @@ TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
|
||||
{
|
||||
int P = 120; // requires padding
|
||||
int Q = 128; // do not require padding
|
||||
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
|
||||
{
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
|
||||
// Kernel can't support odd K because K must be integer multiples of K1 values of either A or B
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
|
||||
// Kernel can't support odd O size because it must satisfy SizeO % B1SrcScalarPerVector == 0
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
@@ -66,3 +70,120 @@ struct TestBatchedGemmGemm : public ::testing::Test
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
|
||||
|
||||
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // StrideC
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
0, // BatchStrideC
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
PassThrough{}, // acc0_element_op
|
||||
PassThrough{}, // b1_element_op
|
||||
PassThrough{}); // c_element_op
|
||||
|
||||
return gemm.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user