mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Let 'api' an alias of 'direction' option
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--apis fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--apis bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--apis fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--apis bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
|
||||
@@ -45,8 +45,10 @@ if __name__ == "__main__":
|
||||
description="gen API for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"-a",
|
||||
"--apis",
|
||||
"--api",
|
||||
default='fwd',
|
||||
choices=['fwd', 'bwd'],
|
||||
required=False,
|
||||
@@ -91,7 +93,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.apis.split(',')
|
||||
api_list = args.direction.split(',')
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user