mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
* uncomment all the headdim, use optdim to control * change default back to -1 * uncomment splitkv instance * Fix typo in receipt 4 for appendkv * support optdim for bwd, splitkv and appendkv * Fix 192 key error --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
133 lines
4.7 KiB
Python
133 lines
4.7 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
# generate kernel instances to speed up compilation
|
|
|
|
import argparse
|
|
from enum import IntEnum
|
|
from pathlib import Path
|
|
import pkgutil
|
|
import sys
|
|
from typing import List, Optional
|
|
|
|
import codegen.ops
|
|
from codegen.cmake_config import *
|
|
|
|
|
|
class HandlerId(IntEnum):
    """Position of each callable inside a registered per-API handler tuple.

    Every entry of the module-level ``handlers`` mapping is a
    ``(list_blobs, write_blobs)`` pair; these members are the indices used
    to pick one of the two.
    """

    # index of the op module's `list_blobs` callable
    LIST_BLOBS = 0
    # index of the op module's `write_blobs` callable
    WRITE_BLOBS = 1
|
|
|
|
# Inspect all modules under 'codegen.ops' and register their API handlers.
import importlib.util


def _load_op_module(finder, module_name):
    """Load one op module through `finder` and return the module object.

    The module is registered in `sys.modules` under its bare `module_name`
    (not the dotted 'codegen.ops.<name>' path), so its `__name__` keeps the
    'fmha_' prefix that the handler table below strips.

    This replaces the deprecated `loader.load_module()` call with the
    `module_from_spec()` / `exec_module()` pair.

    Raises:
        ImportError: if `finder` cannot produce a loadable spec.
    """
    spec = finder.find_spec(module_name)
    if spec is None or spec.loader is None:
        raise ImportError('cannot load op module %r from %s'
                          % (module_name, codegen.ops.__name__))

    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the import system itself does, and undo
    # the registration if executing the module body fails.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module


ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    ops.append(_load_op_module(importer, module_name))

# API names are the op module names without this prefix; a module whose
# name lacks the prefix is registered under its full name.
unwanted_prefix = 'fmha_'

# api name -> (list_blobs, write_blobs); index the tuple with HandlerId.
handlers = {
    (op.__name__[len(unwanted_prefix):]
     if op.__name__.startswith(unwanted_prefix) else op.__name__):
    (op.list_blobs, op.write_blobs)
    for op in ops
}
assert 0 < len(handlers), 'no op modules found under codegen.ops'
|
|
|
|
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
    """Invoke the WRITE_BLOBS handler of every requested API.

    Args:
        output_dir: base directory for the generated files. When None, the
            directory containing this script is used as-is; otherwise
            GEN_DIR (presumably supplied by the `codegen.cmake_config` star
            import) is appended to it.
        api_list: API names, each a key of the module-level `handlers` map.
        filters_list: kernel filters, paired positionally with `api_list`;
            pairing stops at the shorter of the two lists.
        optdim_list, receipt, mask_impl: forwarded unchanged to each handler.

    Raises:
        KeyError: if an API name is not registered in `handlers`.
    """
    target_dir = Path(__file__).parent if output_dir is None else Path(output_dir) / GEN_DIR

    # make sure the destination exists before any handler writes into it
    target_dir.mkdir(parents=True, exist_ok=True)

    for api_name, pattern in zip(api_list, filters_list):
        write_handler = handlers[api_name][HandlerId.WRITE_BLOBS]
        write_handler(target_dir, pattern, receipt, optdim_list, mask_impl)
|
|
|
|
# list all the files that will be generated
|
|
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
    """Invoke the LIST_BLOBS handler of every requested API on one listing file.

    Args:
        output_file: path of the listing file; must not be None.
        api_list: API names, each a key of the module-level `handlers` map.
        filters_list: kernel filters, paired positionally with `api_list`;
            pairing stops at the shorter of the two lists.
        optdim_list, receipt, mask_impl: forwarded unchanged to each handler.

    Raises:
        AssertionError: if `output_file` is None.
        KeyError: if an API name is not registered in `handlers`.
    """
    assert output_file is not None
    listing_path = Path(output_file)

    # Start from an empty file: create it, or drop its contents if it
    # already exists (handlers presumably append to it -- verify in the ops).
    with open(listing_path, "w"):
        pass

    for api_name, pattern in zip(api_list, filters_list):
        list_handler = handlers[api_name][HandlerId.LIST_BLOBS]
        list_handler(listing_path, pattern, receipt, optdim_list, mask_impl)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, then either list the
    # files that would be generated (-l/--list_blobs) or write them
    # (-o/--output_dir).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
        # The --receipt help contains hand-written line breaks; the default
        # formatter would re-wrap them into a single paragraph.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # argparse derives the destination from the first long option, so the
    # parsed value is read back as `args.direction` even when -a/--api is used.
    parser.add_argument(
        "-d",
        "--direction",  # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default='',
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        # let argparse convert and report a usage error on non-integer input
        type=int,
        default=0,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
             " 1: generate more instance to cover all hdim\n"
             " 2: Only generate instance for Flash attention integration\n"
             " 4: Only generate instance for PyTorch integration\n"
             " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
             " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
             " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
             " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
             " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
    )

    parser.add_argument(
        "--optdim",
        default='-1',
        required=False,
        # the two literals used to be joined without a separator
        # ("...default choiceeg. ...")
        help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
             "eg. --optdim=32,64,128,256"
    )

    args = parser.parse_args()
    api_list = args.direction.split(',')
    filter_list = args.filter.split(',')
    # Pad with empty filters so every API has one; zip() in list_blobs /
    # write_blobs would otherwise silently drop the APIs without a filter.
    filter_list.extend([''] * (len(api_list) - len(filter_list)))
    optdim_list = [int(hdim) for hdim in args.optdim.split(',')]

    if args.list_blobs is not None:
        list_blobs(args.list_blobs, api_list, filter_list, optdim_list, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, api_list, filter_list, optdim_list, args.receipt, mask_impl=args.mask)
|