mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Apply filter to every kernel in the codgen of FMHA (#1911)
* add receipt for fwd * Add receipt for bwd * Use kernel name to avoid more receipt * apply filter to every kernel
This commit is contained in:
@@ -30,7 +30,7 @@ handlers = dict(
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
@@ -38,19 +38,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api in api_list:
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, mask_impl)
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api in api_list:
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(file_path, kernel_filter, receipt, mask_impl)
|
||||
|
||||
@@ -84,6 +84,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default='',
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
@@ -105,15 +106,19 @@ if __name__ == "__main__":
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration\n" + \
|
||||
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \
|
||||
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \
|
||||
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
|
||||
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
|
||||
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.direction.split(',')
|
||||
filter_list = args.filter.split(',')
|
||||
filter_list.extend([''] * (len(api_list) - len(filter_list)))
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
|
||||
|
||||
Reference in New Issue
Block a user