diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp index f3c0082452..92f45eccee 100644 --- a/profiler/include/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -232,6 +232,10 @@ void profile_grouped_gemm_impl(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { std::string gemm_name = gemm_ptr->GetTypeString();