mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Add client example of grouped conv2d backward weight (data type: fp16) (#498)
* Remove redundant CMake setting
* Extract common code from files
* Rename folder 'convnd' to 'conv'
* Use std::array<> to accept compile-time kwnown # of arguments
* Fix compilation error of tuning parameter
* In example, use same setting as unit-test
* Remove no-longer used include directive
* Add interface for grouped conv bwd weight
* Add group support for conv bwd weight
* Add grouped conv bwd weight example
* Use group parameter in example
* Rename example folder
* Remove non-grouped version example source files
* Rename device op template
* Add group support to convolution backward weight
* Remove debug messages
* Use smaller group size in example
* Use named variable as loop terminate condition
* Prettify example output message
* Enlarge used grid size
* Allow real grid size exceeds expected grid size
* Rename interface file
* Add client example for grouped conv2d bwd weight
* Fix wrong include directive
* Rename client example folder
[ROCm/composable_kernel commit: 38470e0497]
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/include/profile_conv_bwd_weight_impl.hpp"
|
||||
#include "profiler/include/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
enum struct ConvLayout
|
||||
{
|
||||
NCHW_KCYX_NKHW, // 0
|
||||
NHWC_KYXC_NHWK, // 1
|
||||
GNCHW_GKCYX_GNKHW, // 0
|
||||
GNHWC_GKYXC_GNHWK, // 1
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -25,24 +25,25 @@ enum struct ConvDataType
|
||||
|
||||
static void print_helper_msg()
|
||||
{
|
||||
std::cout
|
||||
<< "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n"
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16)\n"
|
||||
<< "arg3: tensor layout (0: Input[N, C, Hi, Wi], Weight[K, C, Y, X], Output[N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[N, Hi, Wi, C], Weight[K, Y, X, C], Output[N, Ho, Wo, K]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
|
||||
<< std::endl;
|
||||
std::cout << "arg1: tensor operation (conv_bwd_weight: Convolution Backward Weight\n"
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
"N, Ho, Wo, K]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_conv_bwd_weight(int argc, char* argv[])
|
||||
int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
@@ -75,17 +76,17 @@ int profile_conv_bwd_weight(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using NWC = ck::tensor_layout::convolution::NWC;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
using NDHWC = ck::tensor_layout::convolution::NDHWC;
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
|
||||
using KXC = ck::tensor_layout::convolution::KXC;
|
||||
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||
using KZYXC = ck::tensor_layout::convolution::KZYXC;
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWK = ck::tensor_layout::convolution::NWK;
|
||||
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||
using NDHWK = ck::tensor_layout::convolution::NDHWK;
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
@@ -108,64 +109,64 @@ int profile_conv_bwd_weight(int argc, char* argv[])
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
bool pass = ck::profiler::profile_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK)
|
||||
if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,8 +18,8 @@ int profile_conv_fwd(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu_add(int, char*[]);
|
||||
int profile_conv_bwd_data(int, char*[]);
|
||||
int profile_conv_bwd_weight(int, char*[]);
|
||||
int profile_grouped_conv_fwd(int, char*[]);
|
||||
int profile_grouped_conv_bwd_weight(int, char*[]);
|
||||
int profile_softmax(int, char*[]);
|
||||
int profile_layernorm(int, char*[]);
|
||||
int profile_groupnorm(int, char*[]);
|
||||
@@ -43,8 +43,8 @@ static void print_helper_message()
|
||||
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
|
||||
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
|
||||
" conv_bwd_data: Convolution Backward Data\n"
|
||||
" conv_bwd_weight: Convolution Backward Weight\n"
|
||||
" grouped_conv_fwd: Grouped Convolution Forward\n"
|
||||
" grouped_conv_bwd_weight: Grouped Convolution Backward Weight\n"
|
||||
" softmax: Softmax\n"
|
||||
" reduce: Reduce\n");
|
||||
// clang-format on
|
||||
@@ -118,14 +118,14 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_conv_bwd_data(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_bwd_weight") == 0)
|
||||
{
|
||||
return profile_conv_bwd_weight(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "grouped_conv_fwd") == 0)
|
||||
{
|
||||
return profile_grouped_conv_fwd(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_bwd_weight") == 0)
|
||||
{
|
||||
return profile_grouped_conv_bwd_weight(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "reduce") == 0)
|
||||
{
|
||||
return profile_reduce(argc, argv);
|
||||
|
||||
Reference in New Issue
Block a user