mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add KBatch support for gemm_ab_scale (#2740)
* Add KBatch support for gemm_ab_scale
* Revert kernel parameters change
* Remove printing
* fix formatting
* fix check
* Use {} in if
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -58,6 +58,8 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
|
||||
@@ -311,6 +311,12 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
}
|
||||
};
|
||||
|
||||
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
|
||||
{
|
||||
auto& arg = *dynamic_cast<Argument*>(base_arg);
|
||||
arg.KBatch = KBatch;
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
|
||||
Reference in New Issue
Block a user