mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Add client example of grouped conv2d backward weight (data type: fp16) (#498)
* Remove redundant CMake setting
* Extract common code from files
* Rename folder 'convnd' to 'conv'
* Use std::array<> to accept compile-time known # of arguments
* Fix compilation error of tuning parameter
* In example, use same setting as unit-test
* Remove no-longer used include directive
* Add interface for grouped conv bwd weight
* Add group support for conv bwd weight
* Add grouped conv bwd weight example
* Use group parameter in example
* Rename example folder
* Remove non-grouped version example source files
* Rename device op template
* Add group support to convolution backward weight
* Remove debug messages
* Use smaller group size in example
* Use named variable as loop termination condition
* Prettify example output message
* Enlarge used grid size
* Allow real grid size to exceed expected grid size
* Rename interface file
* Add client example for grouped conv2d bwd weight
* Fix wrong include directive
* Rename client example folder
[ROCm/composable_kernel commit: 38470e0497]
This commit is contained in:
@@ -3,9 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
@@ -13,7 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -26,32 +27,6 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType>
|
||||
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
|
||||
{
|
||||
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -59,12 +34,12 @@ template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k)
|
||||
bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -114,16 +89,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>{};
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weight_host_result,
|
||||
output,
|
||||
@@ -138,16 +111,16 @@ bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
@@ -163,22 +136,41 @@ bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
// profile device Conv instances
|
||||
bool all_pass = true;
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
|
||||
|
||||
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
|
||||
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
|
||||
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.G_,
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.output_spatial_lengths_,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
@@ -218,32 +210,29 @@ bool profile_conv_bwd_weight_impl(int do_verification,
|
||||
wei_device_buf.FromDevice(weight_device_result.mData.data());
|
||||
|
||||
bool pass =
|
||||
ck::utils::check_err(weight_host_result.mData, weight_device_result.mData);
|
||||
ck::utils::check_err(weight_device_result.mData, weight_host_result.mData);
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Fail info:" << op_ptr->GetTypeString() << std::endl;
|
||||
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
all_pass &= pass;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(output);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(weight_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out : ";
|
||||
show_data_nhwc_layout(input);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei_device: ";
|
||||
show_data_nhwc_layout(weight_device_result);
|
||||
std::cout << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl;
|
||||
;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "weight (device): ", weight_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "weight (host): ", weight_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
;
|
||||
LogRangeAsType<float>(std::cout << "input: ", input.mData, ",") << std::endl;
|
||||
;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user