mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add two stage grouped conv bwd weight kernel (#1280)
[ROCm/composable_kernel commit: 0b6b5d1785]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
@@ -160,6 +160,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
SimpleDeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
Reference in New Issue
Block a user