diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e324f85ed8..4e8d7732d5 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,12 +1,12 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --apis fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + --apis bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt ) # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --apis fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --apis bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py index f6943ceda6..03ebfd6702 100644 --- a/example/ck_tile/01_fmha/codegen/cmake_config.py +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -1 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 0ea4e74f55..d3d215f7f5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + DTYPE_MAP = { "fp16": "ck_tile::fp16_t", "bf16": "ck_tile::bf16_t", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bdd4aa17f3..0160915a54 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + import copy from dataclasses import dataclass import fnmatch diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 781e06e95b..1486671f6b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + import copy from dataclasses import dataclass import fnmatch diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 17eac0a87b..893fab2b24 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -4,7 +4,7 @@ import argparse from pathlib import Path -from typing import Optional +from typing import List, Optional from codegen.cmake_config import * from codegen.ops import ( @@ -12,40 +12,45 @@ from codegen.ops import ( fmha_bwd ) -def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - if direction == 'fwd': - fmha_fwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl) - else: - fmha_bwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl) + + write_blobs_iml = { + 'fwd': fmha_fwd.write_blobs, + 'bwd': fmha_bwd.write_blobs, + } + for api in api_list: + write_blobs_iml[api](output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) - if direction == 'fwd': - fmha_fwd.list_blobs(file_path, kernel_filter, receipt, mask_impl) - else: - fmha_bwd.list_blobs(file_path, kernel_filter, receipt, mask_impl) + list_blobs_iml = { + 'fwd': fmha_fwd.list_blobs, + 'bwd': fmha_bwd.list_blobs, + } + for api in api_list: + list_blobs_iml[api](file_path, kernel_filter, receipt, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", - description="gen api for CK fmha kernel", + description="gen API for CK fmha kernel", ) parser.add_argument( - "-d", - "--direction", + "-a", + "--apis", default='fwd', choices=['fwd', 'bwd'], required=False, - help="choose the direction of kernels(default: fwd)" + help="supply API(s) to generate (default: fwd). separated by comma." ) parser.add_argument( "-o", @@ -86,7 +91,8 @@ if __name__ == "__main__": ) args = parser.parse_args() + api_list = args.apis.split(',') if args.list_blobs is not None: - list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)