mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Move SetZero functions inside the kernels for Grouped Conv (#2255)
* Disable SetZero before launch kernel for grouped conv fwd * Move set zero to kernel * wmma fix * fix --------- Co-authored-by: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com>
This commit is contained in:
@@ -86,9 +86,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -136,9 +133,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
|
||||
@@ -207,8 +206,6 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// using atomic add, so need to reset input
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
|
||||
@@ -155,9 +155,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
out_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
Reference in New Issue
Block a user