Add two stage grouped conv bwd weight kernel (#1280)

[ROCm/composable_kernel commit: 0b6b5d1785]
This commit is contained in:
Bartłomiej Kocot
2024-05-08 09:53:24 +02:00
committed by GitHub
parent b62c21c3b5
commit 68b2757f11
10 changed files with 1087 additions and 3 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
@@ -160,6 +160,10 @@ bool run_grouped_conv_bwd_weight(
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});