Add support for NGCHW in grouped conv fwd (#1499)

* Support NGCHW in grouped conv fwd

* Remove not needed variable

* Fixes
This commit is contained in:
Bartłomiej Kocot
2024-09-20 10:45:46 +02:00
committed by GitHub
parent 0c39954da9
commit 4ba52b35dc
27 changed files with 1620 additions and 305 deletions

View File

@@ -148,6 +148,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool pass = true;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init output to zero before profiling next kernel