mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
added padding of K into gemm_v2r3 (#887)
* added kpad support into v2r3
* add generic instances
* fixed comments
* fixed mnk padding
* Update device_batched_gemm_xdl.hpp

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpecialization::MNPadding,
|
||||
GemmSpecialization::MNKPadding,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -315,11 +315,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(problem.K % K1 != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(problem);
|
||||
}
|
||||
|
||||
@@ -416,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ">"
|
||||
<< " NumGemmKPrefetchStage: "
|
||||
<< NumGemmKPrefetchStage << ", "
|
||||
|
||||
@@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
StrideC{StrideC_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
K0{CalculateK0(K)}
|
||||
K0{CalculateK0(K_)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
|
||||
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
@@ -840,7 +840,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
|
||||
const auto KPad = K0Pad * K1Value;
|
||||
|
||||
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
|
||||
make_right_pad_transform(M, MPad - M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
@@ -874,7 +892,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
|
||||
const auto KPad = K0Pad * K1Value;
|
||||
|
||||
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
|
||||
Reference in New Issue
Block a user