Grouped convolution forward with clamp (#2334)

* Grouped convolution forward with clamp

* Optimize clamp

* unary fixes

* test gk bias

* Revert "test gk bias"

This reverts commit 8e42e29d7b.

* Revert "Revert "test gk bias""

This reverts commit e73c0550ce.

* workaround comment
This commit is contained in:
Bartłomiej Kocot
2025-06-16 15:36:53 +02:00
committed by GitHub
parent d996bc78be
commit f6c2ff9dce
41 changed files with 2103 additions and 106 deletions

View File

@@ -208,6 +208,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_layout",
"-I",
"--in_layout",
"--I",
default="NCHW",
type=str,
required=False,
@@ -216,6 +218,8 @@ if __name__ == "__main__":
parser.add_argument(
"-forw",
"-F",
"--forw",
"--F",
default=0,
type=int,
required=False,
@@ -231,6 +235,8 @@ if __name__ == "__main__":
parser.add_argument(
"-spatial_dim",
"-_",
"--spatial_dim",
"--_",
default=2,
type=int,
required=False,
@@ -239,6 +245,8 @@ if __name__ == "__main__":
parser.add_argument(
"-batchsize",
"-n",
"--batchsize",
"--n",
default=100,
type=int,
required=False,
@@ -247,6 +255,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_channels",
"-c",
"--in_channels",
"--c",
default=3,
type=int,
required=False,
@@ -255,6 +265,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_d",
"-!",
"--in_d",
"--!",
default=32,
type=int,
required=False,
@@ -263,6 +275,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_h",
"-H",
"--in_h",
"--H",
default=32,
type=int,
required=False,
@@ -271,6 +285,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_w",
"-W",
"--in_w",
"--W",
default=32,
type=int,
required=False,
@@ -279,6 +295,8 @@ if __name__ == "__main__":
parser.add_argument(
"-out_channels",
"-k",
"--out_channels",
"--k",
default=32,
type=int,
required=False,
@@ -287,6 +305,8 @@ if __name__ == "__main__":
parser.add_argument(
"-fil_d",
"-@",
"--fil_d",
"--@",
default=3,
type=int,
required=False,
@@ -295,6 +315,8 @@ if __name__ == "__main__":
parser.add_argument(
"-fil_h",
"-y",
"--fil_h",
"--y",
default=3,
type=int,
required=False,
@@ -303,6 +325,8 @@ if __name__ == "__main__":
parser.add_argument(
"-fil_w",
"-x",
"--fil_w",
"--x",
default=3,
type=int,
required=False,
@@ -311,6 +335,8 @@ if __name__ == "__main__":
parser.add_argument(
"-conv_stride_d",
"-#",
"--conv_stride_d",
"--#",
default=1,
type=int,
required=False,
@@ -319,6 +345,8 @@ if __name__ == "__main__":
parser.add_argument(
"-conv_stride_h",
"-u",
"--conv_stride_h",
"--u",
default=1,
type=int,
required=False,
@@ -327,6 +355,8 @@ if __name__ == "__main__":
parser.add_argument(
"-conv_stride_w",
"-v",
"--conv_stride_w",
"--v",
default=1,
type=int,
required=False,
@@ -335,6 +365,8 @@ if __name__ == "__main__":
parser.add_argument(
"-pad_d",
"-$",
"--pad_d",
"--$",
default=1,
type=int,
required=False,
@@ -343,6 +375,8 @@ if __name__ == "__main__":
parser.add_argument(
"-pad_h",
"-p",
"--pad_h",
"--p",
default=1,
type=int,
required=False,
@@ -351,6 +385,8 @@ if __name__ == "__main__":
parser.add_argument(
"-pad_w",
"-q",
"--pad_w",
"--q",
default=1,
type=int,
required=False,
@@ -359,6 +395,8 @@ if __name__ == "__main__":
parser.add_argument(
"-verify",
"-V",
"--verify",
"--V",
default=1,
type=int,
required=False,
@@ -367,6 +405,8 @@ if __name__ == "__main__":
parser.add_argument(
"-time",
"-t",
"--time",
"--t",
default=0,
type=int,
required=False,
@@ -375,6 +415,8 @@ if __name__ == "__main__":
parser.add_argument(
"-dilation_d",
"-^",
"--dilation_d",
"--^",
default=1,
type=int,
required=False,
@@ -383,6 +425,8 @@ if __name__ == "__main__":
parser.add_argument(
"-dilation_h",
"-l",
"--dilation_h",
"--l",
default=1,
type=int,
required=False,
@@ -391,6 +435,8 @@ if __name__ == "__main__":
parser.add_argument(
"-dilation_w",
"-j",
"--dilation_w",
"--j",
default=1,
type=int,
required=False,
@@ -399,6 +445,8 @@ if __name__ == "__main__":
parser.add_argument(
"-group_count",
"-g",
"--group_count",
"--g",
type=int,
default=1,
required=False,