mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add support for NGCHW in grouped conv bwd wei (#1491)
* Add support for NGCHW in grouped conv bwd wei * Comments fixes * navi fixes * Update function names
This commit is contained in:
@@ -23,6 +23,26 @@ def run_ck_profiler_cmd(cmd):
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def parse_layouts(args):
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 3
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
exit(1)
|
||||
elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
|
||||
args.in_layout == "NDHWC":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 2
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.layout = 1
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
exit(1)
|
||||
|
||||
|
||||
def parse_data_type(args):
|
||||
if args.data_type == "fp32":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
@@ -79,8 +99,7 @@ def add_conv_params_to_cmd(args, cmd):
|
||||
def run_ck_grouped_conv_fwd(args):
|
||||
args.ck_profier_op = "grouped_conv_fwd"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 1
|
||||
parse_layouts(args)
|
||||
# use int32 by default
|
||||
args.index_type = 0
|
||||
|
||||
@@ -99,8 +118,7 @@ def run_ck_grouped_conv_fwd(args):
|
||||
def run_ck_grouped_conv_bwd_data(args):
|
||||
args.ck_profier_op = "grouped_conv_bwd_data"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 1
|
||||
parse_layouts(args)
|
||||
|
||||
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
|
||||
cmd += [str(args.data_type), str(args.layout)]
|
||||
@@ -117,8 +135,7 @@ def run_ck_grouped_conv_bwd_data(args):
|
||||
def run_ck_grouped_conv_bwd_weight(args):
|
||||
args.ck_profier_op = "grouped_conv_bwd_weight"
|
||||
parse_data_type(args)
|
||||
# default for MIOpen NHWGC
|
||||
args.layout = 2
|
||||
parse_layouts(args)
|
||||
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
|
||||
args.split_k_value = -1
|
||||
|
||||
@@ -181,8 +198,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-in_layout",
|
||||
"-I",
|
||||
default=-1,
|
||||
type=int,
|
||||
default="NCHW",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user