mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Patch for bwd data comments (#174)
* change function name and way to set input zero
* change enable if
[ROCm/composable_kernel commit: 6717168c18]
This commit is contained in:
@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance =
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
void PrintUseMsg()
|
||||
void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
|
||||
@@ -99,7 +99,7 @@ void PrintUseMsg()
|
||||
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
|
||||
ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
|
||||
{
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
ck::conv_util::ConvParams params;
|
||||
@@ -144,8 +144,8 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
|
||||
return params;
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -165,8 +165,8 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>
|
||||
}
|
||||
}
|
||||
}
|
||||
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -187,8 +187,8 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_
|
||||
}
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -210,7 +210,7 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t
|
||||
}
|
||||
}
|
||||
|
||||
DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial)
|
||||
DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
@@ -256,15 +256,15 @@ int main(int argc, char* argv[])
|
||||
int cmdline_nargs = conv_args + 5;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
PrintUseMsg();
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params = ParseConvParams(num_dim_spatial, argv);
|
||||
params = parse_conv_params(num_dim_spatial, argv);
|
||||
}
|
||||
else if(argc != 1)
|
||||
{
|
||||
PrintUseMsg();
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -288,11 +288,13 @@ int main(int argc, char* argv[])
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi_host_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<InDataType> in_n_c_hi_wi_device_result(
|
||||
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
|
||||
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(
|
||||
get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
@@ -318,11 +320,10 @@ int main(int argc, char* argv[])
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
// reset input to zero
|
||||
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
in_device_buf.SetZero();
|
||||
|
||||
// do GEMM
|
||||
auto conv = GetConvInstance(num_dim_spatial);
|
||||
auto conv = get_conv_instance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
|
||||
@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
|
||||
} // function end
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
|
||||
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
|
||||
|
||||
@@ -18,8 +18,8 @@ template <typename InDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
|
||||
@@ -336,8 +336,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(input_device_result.mData.data());
|
||||
in_device_buf.SetZero();
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user