mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporary disable splitK for gfx12
This commit is contained in:
@@ -10,7 +10,7 @@ import subprocess
|
||||
|
||||
|
||||
def init_const_args(args):
|
||||
args.ck_profiler_cmd = '../build/bin/ckProfiler'
|
||||
args.ck_profiler_cmd = "../build/bin/ckProfiler"
|
||||
# use decimal values
|
||||
args.init_method = 2
|
||||
# don't print tensor values
|
||||
@@ -27,52 +27,62 @@ def run_ck_profiler_cmd(cmd):
|
||||
|
||||
|
||||
def parse_layouts(args):
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
|
||||
args.in_layout == "NCDHW":
|
||||
if args.in_layout == "NCW" or args.in_layout == "NCHW" or args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 4
|
||||
elif args.ck_profier_op == "grouped_conv_fwd" or \
|
||||
args.ck_profier_op == "grouped_conv_bwd_data":
|
||||
elif (
|
||||
args.ck_profier_op == "grouped_conv_fwd"
|
||||
or args.ck_profier_op == "grouped_conv_bwd_data"
|
||||
):
|
||||
args.layout = 3
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
print("Not supported layout for this op")
|
||||
exit(1)
|
||||
elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
|
||||
args.in_layout == "NDHWC":
|
||||
elif (
|
||||
args.in_layout == "NWC" or args.in_layout == "NHWC" or args.in_layout == "NDHWC"
|
||||
):
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 2
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
elif (
|
||||
args.ck_profier_op == "grouped_conv_bwd_data"
|
||||
or args.ck_profier_op == "grouped_conv_fwd"
|
||||
):
|
||||
args.layout = 1
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
print("Not supported layout for this op")
|
||||
exit(1)
|
||||
|
||||
|
||||
def parse_data_type(args):
|
||||
if args.data_type == "fp32":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
if (
|
||||
args.ck_profier_op == "grouped_conv_bwd_weight"
|
||||
or args.ck_profier_op == "grouped_conv_bwd_data"
|
||||
or args.ck_profier_op == "grouped_conv_fwd"
|
||||
):
|
||||
args.data_type = 0
|
||||
if args.data_type == "fp16":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
|
||||
args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
if (
|
||||
args.ck_profier_op == "grouped_conv_bwd_weight"
|
||||
or args.ck_profier_op == "grouped_conv_bwd_data"
|
||||
or args.ck_profier_op == "grouped_conv_fwd"
|
||||
):
|
||||
args.data_type = 1
|
||||
if args.data_type == "int8":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.data_type = 4
|
||||
if args.ck_profier_op == "grouped_conv_bwd_data":
|
||||
print('Not supported data type for grouped_conv_bwd_data')
|
||||
print("Not supported data type for grouped_conv_bwd_data")
|
||||
exit(1)
|
||||
if args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.data_type = 3
|
||||
if args.data_type == "bfp16":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.data_type = 5
|
||||
if args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
if (
|
||||
args.ck_profier_op == "grouped_conv_bwd_data"
|
||||
or args.ck_profier_op == "grouped_conv_fwd"
|
||||
):
|
||||
args.data_type = 2
|
||||
|
||||
|
||||
@@ -93,13 +103,11 @@ def add_conv_params_to_cmd(args, cmd):
|
||||
cmd += [str(args.in_d), str(args.in_h), str(args.in_w)]
|
||||
cmd += [str(args.conv_stride_d), str(args.conv_stride_h)]
|
||||
cmd += [str(args.conv_stride_w)]
|
||||
cmd += [str(args.dilation_d),
|
||||
str(args.dilation_h),
|
||||
str(args.dilation_w)]
|
||||
cmd += [str(args.dilation_d), str(args.dilation_h), str(args.dilation_w)]
|
||||
cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
|
||||
cmd += [str(args.pad_d), str(args.pad_h), str(args.pad_w)]
|
||||
else:
|
||||
print('Not supported spatial dim (supported: 1, 2, 3)')
|
||||
print("Not supported spatial dim (supported: 1, 2, 3)")
|
||||
exit(1)
|
||||
|
||||
|
||||
@@ -147,7 +155,7 @@ def run_ck_grouped_conv_bwd_weight(args):
|
||||
parse_data_type(args)
|
||||
parse_layouts(args)
|
||||
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
|
||||
args.split_k_value = -1
|
||||
args.split_k_value = "all"
|
||||
|
||||
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
|
||||
cmd += [str(args.data_type), str(args.layout)]
|
||||
@@ -161,23 +169,23 @@ def run_ck_grouped_conv_bwd_weight(args):
|
||||
cmd += [str(args.split_k_value)]
|
||||
run_ck_profiler_cmd(cmd)
|
||||
|
||||
|
||||
# Get name of miopen driver, remove it from unknown
|
||||
def process_miopen_driver_name(args, unknown):
|
||||
if "convint8" in unknown:
|
||||
args.data_type = 'int8'
|
||||
args.data_type = "int8"
|
||||
unknown.remove("convint8")
|
||||
elif "convbfp16" in unknown:
|
||||
args.data_type = 'bfp16'
|
||||
args.data_type = "bfp16"
|
||||
unknown.remove("convbfp16")
|
||||
elif "convfp16" in unknown:
|
||||
args.data_type = 'fp16'
|
||||
args.data_type = "fp16"
|
||||
unknown.remove("convfp16")
|
||||
elif "conv" in unknown:
|
||||
args.data_type = 'fp32'
|
||||
args.data_type = "fp32"
|
||||
unknown.remove("conv")
|
||||
else:
|
||||
print('Not supported driver (supported: conv, convfp16, convint8,'
|
||||
' convbfp16).')
|
||||
print("Not supported driver (supported: conv, convfp16, convint8, convbfp16).")
|
||||
exit(1)
|
||||
|
||||
|
||||
@@ -199,11 +207,11 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="converter",
|
||||
description="Convert miopen driver command to ck Profiler"
|
||||
"\nExample: python3 "
|
||||
"../script/convert_miopen_driver_to_profiler.py "
|
||||
"/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
|
||||
"-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
|
||||
"32 -F 1 -t 1",
|
||||
"\nExample: python3 "
|
||||
"../script/convert_miopen_driver_to_profiler.py "
|
||||
"/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
|
||||
"-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
|
||||
"32 -F 1 -t 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-in_layout",
|
||||
@@ -213,7 +221,7 @@ if __name__ == "__main__":
|
||||
default="NCHW",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
|
||||
help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-forw",
|
||||
@@ -230,7 +238,7 @@ if __name__ == "__main__":
|
||||
"\n4 wrw only"
|
||||
"\n3 fwd+bwd"
|
||||
"\n5 fwd+wrw"
|
||||
"\n6 bwd+wrw"
|
||||
"\n6 bwd+wrw",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-spatial_dim",
|
||||
@@ -240,7 +248,7 @@ if __name__ == "__main__":
|
||||
default=2,
|
||||
type=int,
|
||||
required=False,
|
||||
help="convolution spatial dimension (Default-2)"
|
||||
help="convolution spatial dimension (Default-2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-batchsize",
|
||||
@@ -250,7 +258,7 @@ if __name__ == "__main__":
|
||||
default=100,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Mini-batch size (Default=100)"
|
||||
help="Mini-batch size (Default=100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-in_channels",
|
||||
@@ -260,7 +268,7 @@ if __name__ == "__main__":
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Number of Input Channels (Default=3)"
|
||||
help="Number of Input Channels (Default=3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-in_d",
|
||||
@@ -270,7 +278,7 @@ if __name__ == "__main__":
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Input Depth (Default=32)"
|
||||
help="Input Depth (Default=32)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-in_h",
|
||||
@@ -280,7 +288,7 @@ if __name__ == "__main__":
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Input Height (Default=32)"
|
||||
help="Input Height (Default=32)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-in_w",
|
||||
@@ -290,7 +298,7 @@ if __name__ == "__main__":
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Input Width (Default=32)"
|
||||
help="Input Width (Default=32)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-out_channels",
|
||||
@@ -300,7 +308,7 @@ if __name__ == "__main__":
|
||||
default=32,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Number of Output Channels (Default=32)"
|
||||
help="Number of Output Channels (Default=32)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-fil_d",
|
||||
@@ -310,7 +318,7 @@ if __name__ == "__main__":
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Filter Depth (Default=3)"
|
||||
help="Filter Depth (Default=3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-fil_h",
|
||||
@@ -320,7 +328,7 @@ if __name__ == "__main__":
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Filter Height (Default=3)"
|
||||
help="Filter Height (Default=3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-fil_w",
|
||||
@@ -330,7 +338,7 @@ if __name__ == "__main__":
|
||||
default=3,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Filter Width (Default=3)"
|
||||
help="Filter Width (Default=3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-conv_stride_d",
|
||||
@@ -340,7 +348,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Convolution Stride for Depth (Default=1)"
|
||||
help="Convolution Stride for Depth (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-conv_stride_h",
|
||||
@@ -350,7 +358,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Convolution Stride for Height (Default=1)"
|
||||
help="Convolution Stride for Height (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-conv_stride_w",
|
||||
@@ -360,7 +368,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Convolution Stride for Width (Default=1)"
|
||||
help="Convolution Stride for Width (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pad_d",
|
||||
@@ -370,7 +378,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Zero Padding for Depth (Default=0)"
|
||||
help="Zero Padding for Depth (Default=0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pad_h",
|
||||
@@ -380,7 +388,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Zero Padding for Height (Default=0)"
|
||||
help="Zero Padding for Height (Default=0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pad_w",
|
||||
@@ -390,7 +398,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Zero Padding for Width (Default=0)"
|
||||
help="Zero Padding for Width (Default=0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-verify",
|
||||
@@ -400,7 +408,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Verify Each Layer (Default=1)"
|
||||
help="Verify Each Layer (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-time",
|
||||
@@ -410,7 +418,7 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Time Each Layer (Default=0)"
|
||||
help="Time Each Layer (Default=0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-dilation_d",
|
||||
@@ -420,7 +428,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Dilation of Filter Depth (Default=1)"
|
||||
help="Dilation of Filter Depth (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-dilation_h",
|
||||
@@ -430,7 +438,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Dilation of Filter Height (Default=1)"
|
||||
help="Dilation of Filter Height (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-dilation_w",
|
||||
@@ -440,7 +448,7 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
type=int,
|
||||
required=False,
|
||||
help="Dilation of Filter Width (Default=1)"
|
||||
help="Dilation of Filter Width (Default=1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-group_count",
|
||||
@@ -450,7 +458,7 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
help="Number of Groups (Default=1)"
|
||||
help="Number of Groups (Default=1)",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
Reference in New Issue
Block a user