mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Pass gemm_descs for grouped gemm via __constant__ buff (#232)
* moved gemm_descs_args into const buff
* use CK_CONSTANT_ADDRESS_SPACE instead of global constant
* clean
* moved hipMemAlloc outside of deviceOp
* add SetWorkSpacePointer
* fix ignore
This commit is contained in:
@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
// do GEMM
|
||||
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
|
||||
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
|
||||
|
||||
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
|
||||
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
|
||||
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get());
|
||||
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
|
||||
Reference in New Issue
Block a user