mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
* remove switch for NDimSpatial
* rename the in, out and wei tensor variables (to input, weights and output)
* rename the reference lambda functions (f_nchw -> f_ncw / f_ncdhw)
* remove test
[ROCm/composable_kernel commit: c0e95f6204]
This commit is contained in:
@@ -71,7 +71,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto wi) {
|
||||
auto f_ncw = [&](auto n, auto c, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
|
||||
@@ -108,7 +108,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
make_ParallelTensorFunctor(f_ncw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2])(
|
||||
@@ -182,7 +182,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto di, auto hi, auto wi) {
|
||||
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
|
||||
@@ -252,7 +252,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
make_ParallelTensorFunctor(f_ncdhw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2],
|
||||
|
||||
@@ -120,7 +120,6 @@ HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::siz
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
|
||||
}
|
||||
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
@@ -274,13 +273,13 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths,
|
||||
const std::vector<ck::index_t>& conv_filter_strides,
|
||||
const std::vector<ck::index_t>& conv_filter_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -304,51 +303,50 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi_host_result(
|
||||
Tensor<InDataType> input_host_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<InDataType> in_n_c_hi_wi_device_result(
|
||||
Tensor<InDataType> input_device_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(
|
||||
Tensor<WeiDataType> weights(
|
||||
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
Tensor<OutDataType> output(
|
||||
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
std::cout << "input: " << input_host_result.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) *
|
||||
in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
|
||||
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
|
||||
in_device_buf.ToDevice(input_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto RunReference = [&](auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
auto ref_argument = ref_conv.MakeArgument(input_host_result,
|
||||
weights,
|
||||
output,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -358,48 +356,16 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
OutElementOp{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
};
|
||||
switch(NDimSpatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
3>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
2>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
1>();
|
||||
RunReference(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NDimSpatial>();
|
||||
RunReference(ref_conv);
|
||||
}
|
||||
|
||||
// add device Conv instances
|
||||
@@ -468,9 +434,9 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
in_device_buf.FromDevice(input_device_result.mData.data());
|
||||
|
||||
if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result))
|
||||
if(!check_out(input_host_result, input_device_result))
|
||||
{
|
||||
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
|
||||
@@ -481,24 +447,24 @@ bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
|
||||
check_error(input_host_result, input_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(out_n_k_ho_wo);
|
||||
show_data_nhwc_layout(output);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(wei_k_c_y_x);
|
||||
show_data_nhwc_layout(weights);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_host : ";
|
||||
show_data_nhwc_layout(in_n_c_hi_wi_host_result);
|
||||
show_data_nhwc_layout(input_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_device: ";
|
||||
show_data_nhwc_layout(in_n_c_hi_wi_device_result);
|
||||
show_data_nhwc_layout(input_device_result);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user