Pass gemm_descs for grouped gemm via __constant__ buffer (#232)

* moved gemm_descs_args into a __constant__ buffer

* use CK_CONSTANT_ADDRESS_SPACE instead of global constant

* clean

* moved hipMemAlloc outside of deviceOp

* add SetWorkSpacePointer

* fix ignore
This commit is contained in:
zjing14
2022-05-31 17:00:43 -05:00
committed by GitHub
parent 7b1e2c379e
commit b6eaf3eb7e
4 changed files with 111 additions and 111 deletions

View File

@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto c_element_op = PassThrough{};
// do GEMM
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_shapes.size(); i++)