mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Patch for bwd data comments (#174)
* change function name and way to set input zero
* change enable if
[ROCm/composable_kernel commit: 6717168c18]
This commit is contained in:
@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance =
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
void PrintUseMsg()
|
||||
void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
|
||||
@@ -99,7 +99,7 @@ void PrintUseMsg()
|
||||
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
|
||||
ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
|
||||
{
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
ck::conv_util::ConvParams params;
|
||||
@@ -144,8 +144,8 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
|
||||
return params;
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -165,8 +165,8 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>
|
||||
}
|
||||
}
|
||||
}
|
||||
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -187,8 +187,8 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_
|
||||
}
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -210,7 +210,7 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t
|
||||
}
|
||||
}
|
||||
|
||||
DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial)
|
||||
DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
@@ -256,15 +256,15 @@ int main(int argc, char* argv[])
|
||||
int cmdline_nargs = conv_args + 5;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
PrintUseMsg();
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params = ParseConvParams(num_dim_spatial, argv);
|
||||
params = parse_conv_params(num_dim_spatial, argv);
|
||||
}
|
||||
else if(argc != 1)
|
||||
{
|
||||
PrintUseMsg();
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -288,11 +288,13 @@ int main(int argc, char* argv[])
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi_host_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<InDataType> in_n_c_hi_wi_device_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
|
||||
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(
|
||||
get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
@@ -318,11 +320,10 @@ int main(int argc, char* argv[])
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
// reset input to zero
|
||||
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
in_device_buf.SetZero();
|
||||
|
||||
// do GEMM
|
||||
auto conv = GetConvInstance(num_dim_spatial);
|
||||
auto conv = get_conv_instance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
|
||||
Reference in New Issue
Block a user