Add KBatch support for gemm_ab_scale (#2740)

* Add KBatch support for gemm_ab_scale

* Revert kernel parameters change

* Remove printing

* fix formatting

* fix check

* Use {} in if

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Sami Remes
2025-10-09 07:33:16 +01:00
committed by GitHub
parent e99356dabc
commit 9d4bfe3932
5 changed files with 34 additions and 12 deletions

View File

@@ -58,6 +58,8 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
};
template <typename ALayout,

View File

@@ -311,6 +311,12 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
}
};
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
{
auto& arg = *dynamic_cast<Argument*>(base_arg);
arg.KBatch = KBatch;
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check