Add MNK padding, M = 0 support into grouped_gemm (#539)

* add mnk padding, support m=0

* clean code

* clean code

Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
zjing14
2022-12-15 15:07:24 -06:00
committed by GitHub
parent 1115117503
commit 0345963eef
6 changed files with 75 additions and 1 deletions

View File

@@ -373,12 +373,20 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
gemm_desc_kernel_arg_.reserve(group_count_);
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M == 0)
{
skipped_group_count_++;
continue;
}
const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].stride_C_;
@@ -470,6 +478,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
// private:
index_t group_count_;
index_t skipped_group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_;
@@ -581,7 +591,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
-if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
+if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
+arg.skipped_group_count_) != arg.group_count_)
{
return false;
}