mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Support broadcast for bias in grouped conv fwd (#1081)
* Support broadcast for bias in grouped conv fwd
* Fix comment
* Comment fixes
* Remove GK layout
[ROCm/composable_kernel commit: f836984891]
This commit is contained in:
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
|
||||
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
ck::Tuple<NDHWGK, NDHWGK>,
|
||||
ck::Tuple<NDHWGK, G_K>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
@@ -22,13 +22,13 @@ using S = ck::Sequence<Is...>;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using GK = ck::tensor_layout::convolution::G_K;
|
||||
using G_K = ck::tensor_layout::convolution::G_K;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
|
||||
using GK_Tuple = ck::Tuple<GK>;
|
||||
using GK_GK_Tuple = ck::Tuple<GK, GK>;
|
||||
using GK_Tuple = ck::Tuple<G_K>;
|
||||
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
|
||||
using I32_Tuple = ck::Tuple<int32_t>;
|
||||
using F32_Tuple = ck::Tuple<float>;
|
||||
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
|
||||
|
||||
Reference in New Issue
Block a user