mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Support NHWGC conv2d_bwd_weight (#769)
* Support NHWGC conv2d_bwd_weight
* Fix client example
* Fix client example
* Fix comments
* Redesign grouped_conv_bwd_weight instances
* Clang format fix
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 1ee99dcaa6]
This commit is contained in:
@@ -101,13 +101,15 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool run_grouped_conv_bwd_weight(
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
|
||||
@@ -157,6 +159,8 @@ bool run_grouped_conv_bwd_weight(
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -224,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -22,6 +22,15 @@ static constexpr ck::index_t C = 192;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -31,7 +40,19 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1})
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,17 @@ static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -34,8 +45,19 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -41,13 +52,15 @@ int main()
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -37,17 +48,20 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
OutLayout>(
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
|
||||
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user