mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add p_workspace_ to BaseArgument (#275)
This commit is contained in:
@@ -15,6 +15,8 @@ struct BaseArgument
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
|
||||
void* p_workspace_ = nullptr;
|
||||
};
|
||||
|
||||
struct BaseInvoker
|
||||
@@ -42,7 +44,11 @@ struct BaseOperator
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final
|
||||
{
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
}
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
gemm_descs_args_workspace_ = nullptr;
|
||||
p_workspace_ = nullptr;
|
||||
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
|
||||
|
||||
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
|
||||
|
||||
void* gemm_descs_args_workspace_;
|
||||
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpy(arg.gemm_descs_args_workspace_,
|
||||
hipMemcpy(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
|
||||
hipMemcpyHostToDevice));
|
||||
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
|
||||
CElementwiseOperation,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.p_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
|
||||
CElementwiseOperation,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.p_workspace_),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
|
||||
{
|
||||
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user