Support broadcast for bias in grouped conv fwd (#1081)

* Support broadcast for bias in grouped conv fwd

* Fix comment

* Comment fixes

* Remove GK layout
This commit is contained in:
Bartłomiej Kocot
2023-12-08 11:07:42 +01:00
committed by GitHub
parent d939411dae
commit f836984891
15 changed files with 371 additions and 55 deletions

View File

@@ -16,6 +16,7 @@
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
@@ -64,6 +65,9 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
std::array<ck::index_t, 6> out_strides{
K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
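// Logically broadcast bias: its lengths and strides must be passed in the same dimension order as
// the output (G, N, K, D, H, W); the broadcast dimensions (N, D, H, W) use length 1 and stride 0.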
std::array<ck::index_t, 6> bias_lengths{G, 1, K, 1, 1, 1};
std::array<ck::index_t, 6> bias_strides{K, 0, 1, 0, 0, 0};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
@@ -74,13 +78,13 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, OutLayout>,
ck::Tuple<OutLayout, BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
@@ -117,8 +121,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
in_strides,
wei_lengths,
wei_strides,
{out_lengths, out_lengths},
{out_strides, out_strides},
{out_lengths, bias_lengths},
{out_strides, bias_strides},
out_lengths,
out_strides,
filter_strides,
@@ -187,8 +191,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
in_strides,
wei_lengths,
wei_strides,
{out_lengths, out_lengths},
{out_strides, out_strides},
{out_lengths, bias_lengths},
{out_strides, bias_strides},
out_lengths,
out_strides,
filter_strides,