Files
jiangyon.ren 4d2f8c111e [CK_TILE][FMHA] Add sparse attention VSA (#3341)
* add sparse attention VSA

* fix the pre-commit

* Add jenga test and pre-commit

* add bf16 for vsa

* add jenga support bf16

* remove lse arg

* split kernel code to block & kernel

* fix the pre-commit

* fix the pre-commit

* fix the copyrights

* fix the copyright

* fix the copyright & rename block to pipeline

* fix the copyright and pipeline

* remove lse & dropout & add fmt

* fix the jenga&VSA code review

* remove the useless code & resolved the comments

* remove useless code

* remove useless code

* Clean up code

* Remove more unused code

* Re-format .hpp

* Refactor codegen scripts

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
2026-01-31 00:59:47 +08:00

167 lines
4.8 KiB
Python

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
from typing import List, Optional
import codegen.ops
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = "fmha_"
handlers = dict(
[
(
op.__name__[len(unwanted_prefix) :]
if op.__name__.startswith(unwanted_prefix)
else op.__name__,
(op.list_blobs, op.write_blobs),
)
for op in ops
]
)
assert 0 < len(handlers)
def write_blobs(
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK fmha kernel",
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default="fwd_jenga",
required=False,
help="supply API(s) to generate (default: fwd). separated by comma.",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory",
)
parser.add_argument(
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default="",
required=False,
help="filter out kernels that need to generate, using fnmatch module",
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic",
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n"
+ " 1: generate more instance to cover all hdim\n"
+ " 2: Only generate instance for Flash attention integration\n"
+ " 4: Only generate instance for PyTorch integration\n"
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
)
parser.add_argument(
"--optdim",
default="-1",
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
+ "eg. --optdim=32,64,128,256",
)
args = parser.parse_args()
api_list = args.direction.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
if args.list_blobs is not None:
list_blobs(
args.list_blobs,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)
else:
write_blobs(
args.output_dir,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)