mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Support generate multiple APIs for an example
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--apis fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
--apis bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
)
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--apis fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--apis bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
@@ -1,3 +1,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.ops import (
|
||||
@@ -12,40 +12,45 @@ from codegen.ops import (
|
||||
fmha_bwd
|
||||
)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir) / GEN_DIR
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
if direction == 'fwd':
|
||||
fmha_fwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl)
|
||||
else:
|
||||
fmha_bwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl)
|
||||
|
||||
write_blobs_iml = {
|
||||
'fwd': fmha_fwd.write_blobs,
|
||||
'bwd': fmha_bwd.write_blobs,
|
||||
}
|
||||
for api in api_list:
|
||||
write_blobs_iml[api](output_dir, kernel_filter, receipt, mask_impl)
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
if direction == 'fwd':
|
||||
fmha_fwd.list_blobs(file_path, kernel_filter, receipt, mask_impl)
|
||||
else:
|
||||
fmha_bwd.list_blobs(file_path, kernel_filter, receipt, mask_impl)
|
||||
list_blobs_iml = {
|
||||
'fwd': fmha_fwd.list_blobs,
|
||||
'bwd': fmha_bwd.list_blobs,
|
||||
}
|
||||
for api in api_list:
|
||||
list_blobs_iml[api](file_path, kernel_filter, receipt, mask_impl)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen api for CK fmha kernel",
|
||||
description="gen API for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction",
|
||||
"-a",
|
||||
"--apis",
|
||||
default='fwd',
|
||||
choices=['fwd', 'bwd'],
|
||||
required=False,
|
||||
help="choose the direction of kernels(default: fwd)"
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
@@ -86,7 +91,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.apis.split(',')
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
|
||||
Reference in New Issue
Block a user