mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fix gemm-softmax-gemm-permute padding cases (#409)
* fix example; make padding on by default in example; fix argument checks * fix Gemm1KPacK which has since regressed from PR #399
This commit is contained in:
@@ -594,10 +594,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
|
||||
constexpr index_t Gemm1KPack = math::max(
|
||||
math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
|
||||
// selected_mfma.k_per_blk <= Gemm1KPack
|
||||
//
|
||||
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
|
||||
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
|
||||
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
|
||||
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
|
||||
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
|
||||
// therefore we may just as well assign Gemm1KPack = group_size
|
||||
constexpr index_t Gemm1KPack =
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
|
||||
|
||||
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -645,10 +645,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
|
||||
constexpr index_t Gemm1KPack = math::max(
|
||||
math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
|
||||
// selected_mfma.k_per_blk <= Gemm1KPack
|
||||
//
|
||||
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
|
||||
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
|
||||
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
|
||||
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
|
||||
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
|
||||
// therefore we may just as well assign Gemm1KPack = group_size
|
||||
constexpr index_t Gemm1KPack =
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
|
||||
|
||||
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
Reference in New Issue
Block a user