mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Grouped convolution forward with clamp (#2334)
* Grouped convolution forward with clamp * Optimize clamp * unary fixes * test gk bias * Revert "test gk bias" This reverts commit8e42e29d7b. * Revert "Revert "test gk bias"" This reverts commite73c0550ce. * workaround comment
This commit is contained in:
@@ -208,6 +208,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_layout",
|
||||
"-I",
|
||||
"--in_layout",
|
||||
"--I",
|
||||
default="NCHW",
|
||||
type=str,
|
||||
required=False,
|
||||
@@ -216,6 +218,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-forw",
|
||||
"-F",
|
||||
"--forw",
|
||||
"--F",
|
||||
default=0,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -231,6 +235,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-spatial_dim",
|
||||
"-_",
|
||||
"--spatial_dim",
|
||||
"--_",
|
||||
default=2,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -239,6 +245,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-batchsize",
|
||||
"-n",
|
||||
"--batchsize",
|
||||
"--n",
|
||||
default=100,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -247,6 +255,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_channels",
|
||||
"-c",
|
||||
"--in_channels",
|
||||
"--c",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -255,6 +265,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_d",
|
||||
"-!",
|
||||
"--in_d",
|
||||
"--!",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -263,6 +275,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_h",
|
||||
"-H",
|
||||
"--in_h",
|
||||
"--H",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -271,6 +285,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_w",
|
||||
"-W",
|
||||
"--in_w",
|
||||
"--W",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -279,6 +295,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-out_channels",
|
||||
"-k",
|
||||
"--out_channels",
|
||||
"--k",
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -287,6 +305,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_d",
|
||||
"-@",
|
||||
"--fil_d",
|
||||
"--@",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -295,6 +315,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_h",
|
||||
"-y",
|
||||
"--fil_h",
|
||||
"--y",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -303,6 +325,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-fil_w",
|
||||
"-x",
|
||||
"--fil_w",
|
||||
"--x",
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -311,6 +335,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_d",
|
||||
"-#",
|
||||
"--conv_stride_d",
|
||||
"--#",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -319,6 +345,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_h",
|
||||
"-u",
|
||||
"--conv_stride_h",
|
||||
"--u",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -327,6 +355,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-conv_stride_w",
|
||||
"-v",
|
||||
"--conv_stride_w",
|
||||
"--v",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -335,6 +365,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_d",
|
||||
"-$",
|
||||
"--pad_d",
|
||||
"--$",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -343,6 +375,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_h",
|
||||
"-p",
|
||||
"--pad_h",
|
||||
"--p",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -351,6 +385,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-pad_w",
|
||||
"-q",
|
||||
"--pad_w",
|
||||
"--q",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -359,6 +395,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-verify",
|
||||
"-V",
|
||||
"--verify",
|
||||
"--V",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -367,6 +405,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-time",
|
||||
"-t",
|
||||
"--time",
|
||||
"--t",
|
||||
default=0,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -375,6 +415,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_d",
|
||||
"-^",
|
||||
"--dilation_d",
|
||||
"--^",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -383,6 +425,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_h",
|
||||
"-l",
|
||||
"--dilation_h",
|
||||
"--l",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -391,6 +435,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-dilation_w",
|
||||
"-j",
|
||||
"--dilation_w",
|
||||
"--j",
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
@@ -399,6 +445,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-group_count",
|
||||
"-g",
|
||||
"--group_count",
|
||||
"--g",
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
|
||||
Reference in New Issue
Block a user