Apply filter to every kernel in the codgen of FMHA (#1911)

* add receipt for fwd

* Add receipt for bwd

* Use kernel name to avoid more receipt

* apply filter to every kernel
This commit is contained in:
rocking
2025-02-26 20:20:29 +08:00
committed by GitHub
parent c9bcfd755e
commit e9ee568681
5 changed files with 126 additions and 66 deletions

View File

@@ -30,7 +30,7 @@ handlers = dict(
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
@@ -38,19 +38,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
output_dir.mkdir(parents=True, exist_ok=True)
for api in api_list:
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api in api_list:
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl)
@@ -84,6 +84,7 @@ if __name__ == "__main__":
parser.add_argument(
"-f",
"--filter",
default='',
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
@@ -105,15 +106,19 @@ if __name__ == "__main__":
" 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration\n" + \
" 4: Only generate instance for PyTorch integration\n" + \
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
)
args = parser.parse_args()
api_list = args.direction.split(',')
filter_list = args.filter.split(',')
filter_list.extend([''] * (len(api_list) - len(filter_list)))
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask)