Add support for NGCHW in grouped conv bwd wei (#1491)

* Add support for NGCHW in grouped conv bwd wei

* Comments fixes

* navi fixes

* Update function names
This commit is contained in:
Bartłomiej Kocot
2024-09-03 10:52:03 +02:00
committed by GitHub
parent a9b170b541
commit 73b67f290f
24 changed files with 893 additions and 89 deletions

View File

@@ -23,6 +23,26 @@ def run_ck_profiler_cmd(cmd):
subprocess.run(cmd)
def parse_layouts(args):
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
args.in_layout == "NCDHW":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 3
else:
print('Not supported layout for this op')
exit(1)
elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
args.in_layout == "NDHWC":
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 2
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
args.layout = 1
else:
print('Not supported layout for this op')
exit(1)
def parse_data_type(args):
if args.data_type == "fp32":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
@@ -79,8 +99,7 @@ def add_conv_params_to_cmd(args, cmd):
def run_ck_grouped_conv_fwd(args):
args.ck_profier_op = "grouped_conv_fwd"
parse_data_type(args)
# default for MIOpen NHWGC
args.layout = 1
parse_layouts(args)
# use int32 by default
args.index_type = 0
@@ -99,8 +118,7 @@ def run_ck_grouped_conv_fwd(args):
def run_ck_grouped_conv_bwd_data(args):
args.ck_profier_op = "grouped_conv_bwd_data"
parse_data_type(args)
# default for MIOpen NHWGC
args.layout = 1
parse_layouts(args)
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
cmd += [str(args.data_type), str(args.layout)]
@@ -117,8 +135,7 @@ def run_ck_grouped_conv_bwd_data(args):
def run_ck_grouped_conv_bwd_weight(args):
args.ck_profier_op = "grouped_conv_bwd_weight"
parse_data_type(args)
# default for MIOpen NHWGC
args.layout = 2
parse_layouts(args)
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
args.split_k_value = -1
@@ -181,8 +198,8 @@ if __name__ == "__main__":
parser.add_argument(
"-in_layout",
"-I",
default=-1,
type=int,
default="NCHW",
type=str,
required=False,
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
)