Support generating multiple APIs for an example

This commit is contained in:
PoYen, Chen
2024-06-21 10:04:17 +00:00
parent 925d25ff47
commit 51487f238a
6 changed files with 43 additions and 21 deletions

View File

@@ -1,12 +1,12 @@
# generate a list of kernels, but not actually emit files at config stage
# This call runs generate.py during configuration with --list_blobs, pointing it at
# fwd_blob_list.txt in the current binary directory.
# NOTE(review): the `--direction fwd ...` and `--apis fwd ...` lines below appear to be the
# removed and added sides of one diff line (markers were lost); presumably only the
# `--apis` form is passed in the new revision -- confirm against the real CMakeLists.txt.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--apis fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)
# Same configure-time listing step for the backward kernels: generate.py is run with
# --list_blobs pointing at bwd_blob_list.txt in the current binary directory.
# NOTE(review): `--direction bwd ...` / `--apis bwd ...` appear to be the old and new sides
# of one diff line; presumably only the `--apis` form remains after this commit -- confirm.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
--apis bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
# Build rule whose outputs are the files named in FMHA_FWD_GEN_BLOBS; it runs generate.py
# with --output_dir set to the current binary directory.
# NOTE(review): `--direction fwd ...` / `--apis fwd ...` appear to be the old and new sides
# of one diff line; presumably only the `--apis` form remains after this commit -- confirm.
# NOTE(review): no DEPENDS or VERBATIM is visible in this command, so changes to
# generate.py may not retrigger generation -- confirm against the full file.
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--apis fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
# Build rule whose outputs are the files named in FMHA_BWD_GEN_BLOBS; it runs generate.py
# with --output_dir set to the current binary directory.
# NOTE(review): `--direction bwd ...` / `--apis bwd ...` appear to be the old and new sides
# of one diff line; presumably only the `--apis` form remains after this commit -- confirm.
# NOTE(review): no DEPENDS or VERBATIM is visible in this command, so changes to
# generate.py may not retrigger generation -- confirm against the full file.
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--apis bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")

View File

@@ -1 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder

View File

@@ -1,3 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",

View File

@@ -1,3 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch

View File

@@ -1,3 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch

View File

@@ -4,7 +4,7 @@
import argparse
from pathlib import Path
from typing import Optional
from typing import List, Optional
from codegen.cmake_config import *
from codegen.ops import (
@@ -12,40 +12,45 @@ from codegen.ops import (
fmha_bwd
)
# Emit the generated kernel files for the requested API(s).
#   output_dir:    destination directory; when None, the directory containing this script
#                  is used, otherwise Path(output_dir) / GEN_DIR.
#   api_list:      API names to generate; each is looked up in the dispatch table below,
#                  whose keys are 'fwd' and 'bwd'.
#   kernel_filter, receipt, mask_impl: forwarded unchanged to the per-API write_blobs.
# NOTE(review): this span is a diff hunk whose +/- markers and indentation were lost, so
# both the removed `direction` version (first `def`, if/else dispatch) and the added
# `api_list` version (second `def`, dict dispatch) are present; presumably only the
# `api_list` version exists in the new revision -- confirm against the real file.
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
# create the destination (and any missing parents); an already existing directory is fine
output_dir.mkdir(parents=True, exist_ok=True)
if direction == 'fwd':
fmha_fwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl)
else:
fmha_bwd.write_blobs(output_dir, kernel_filter, receipt, mask_impl)
# dispatch table: API name -> that API's write_blobs function.
# A name that is not a key here raises KeyError in the loop below.
write_blobs_iml = {
'fwd': fmha_fwd.write_blobs,
'bwd': fmha_bwd.write_blobs,
}
for api in api_list:
write_blobs_iml[api](output_dir, kernel_filter, receipt, mask_impl)
# list all the files that will be generated
#   output_file:   path of the list file; must not be None (asserted below).
#   api_list:      API names to list; each is looked up in the dispatch table below,
#                  whose keys are 'fwd' and 'bwd'.
#   kernel_filter, receipt, mask_impl: forwarded unchanged to the per-API list_blobs.
# NOTE(review): this span is a diff hunk whose +/- markers and indentation were lost, so
# both the removed `direction` version and the added `api_list` version are present;
# presumably only the `api_list` version exists in the new revision -- confirm.
# NOTE(review): every API's list_blobs receives the same file_path; whether entries are
# appended or overwritten is decided inside codegen.ops and is not visible here --
# verify before passing more than one API.
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
if direction == 'fwd':
fmha_fwd.list_blobs(file_path, kernel_filter, receipt, mask_impl)
else:
fmha_bwd.list_blobs(file_path, kernel_filter, receipt, mask_impl)
# dispatch table: API name -> that API's list_blobs function.
# A name that is not a key here raises KeyError in the loop below.
list_blobs_iml = {
'fwd': fmha_fwd.list_blobs,
'bwd': fmha_bwd.list_blobs,
}
for api in api_list:
list_blobs_iml[api](file_path, kernel_filter, receipt, mask_impl)
# Command-line entry point: build the argument parser for the generator.
# NOTE(review): old and new diff lines are interleaved below (two `description=` lines,
# `-d/--direction` next to `-a/--apis`, two `help=` lines); presumably only the second of
# each pair exists in the new revision -- confirm against the real file.
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen api for CK fmha kernel",
description="gen API for CK fmha kernel",
)
# API selection option; the value is later split on ',' into a list of API names.
# NOTE(review): `choices=['fwd', 'bwd']` appears to be kept by this commit. argparse checks
# the whole argument string against `choices`, so a comma-separated value such as
# "fwd,bwd" would be rejected before it can be split -- confirm, and drop or relax
# `choices` if multiple APIs are meant to be accepted.
parser.add_argument(
"-d",
"--direction",
"-a",
"--apis",
default='fwd',
choices=['fwd', 'bwd'],
required=False,
help="choose the direction of kernels(default: fwd)"
help="supply API(s) to generate (default: fwd). separated by comma."
)
parser.add_argument(
"-o",
@@ -86,7 +91,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# --apis arrives as one string; split it on ',' to get the list of API names.
# No whitespace stripping or empty-entry filtering is done here.
api_list = args.apis.split(',')
# With --list_blobs, call list_blobs on that file; otherwise call write_blobs on --output_dir.
# NOTE(review): each call appears twice because the removed (`args.direction`) and added
# (`api_list`) diff lines are interleaved; presumably only the `api_list` calls remain in
# the new revision -- confirm against the real file.
if args.list_blobs is not None:
list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
else:
write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)