Only generate specific hdim (#2120)

This commit is contained in:
rocking
2025-04-24 18:52:58 +08:00
committed by GitHub
parent 5487289fc4
commit 02ce6d39ea
5 changed files with 42 additions and 19 deletions

View File

@@ -30,7 +30,7 @@ handlers = dict(
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
@@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list :
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, mask_impl)
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
@@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list :
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl)
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -113,12 +113,24 @@ if __name__ == "__main__":
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
)
parser.add_argument(
"--optdim",
default='-1',
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
"eg. --optdim=32,64,128,256"
)
args = parser.parse_args()
api_list = args.direction.split(',')
filter_list = args.filter.split(',')
filter_list.extend([''] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
if len(api_list) > 1:
assert optdim_list == [-1]
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
else:
write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)