Pass gemm_descs for grouped gemm via __constant__ buff (#232)

* moved gemm_descs_args into const buff

* use CK_CONSTANT_ADDRESS_SPACE instead of global constant

* clean

* moved hipMemAlloc outside of deviceOp

* add SetWorkSpacePointer

* fix ignore
This commit is contained in:
zjing14
2022-05-31 17:00:43 -05:00
committed by GitHub
parent 7b1e2c379e
commit b6eaf3eb7e
4 changed files with 111 additions and 111 deletions

View File

@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0);
}
int group_count = 4;
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(