mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fixed GroupedGemmFixedNK with hipGraph (#1065)
* fixed examples; add async_mem_set
* add stream to all deviceOp using SetWorkspace

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -59,7 +59,9 @@ struct BaseOperator
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const
|
||||
{
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
|
||||
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& stream_config = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
|
||||
hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
|
||||
hip_check_error(
|
||||
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
|
||||
|
||||
@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user