mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Grouped conv bwd data NGCHW (#1967)
* Grouped conv bwd data NGCHW * fixes * fix * Improvements * Fix * Fix * add client example
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -125,6 +125,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
|
||||
Reference in New Issue
Block a user