diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 1e6755c631..932f6020b6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -866,9 +866,11 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: @@ -881,9 +883,11 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] with file_path.open('a') as f: kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3634810b37..c31a0ce954 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -429,7 +429,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -507,6 +507,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] @@ -557,15 +560,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f243020dc4..dc7ef712e2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -343,13 +343,15 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] with file_path.open('a') as f: _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 0dccdf6bd6..ca49af1496 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -440,10 +440,10 @@ class FmhaFwdSplitKVCombinePipeline: n = f'{self.tag}' if pn != '' : n += f'_{pn}' else: n += '_npad' - + if self.F_lse == 't' : n += '_lse' else: n += '_nlse' - + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' return n @@ -819,9 +819,10 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: @@ -831,9 +832,10 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] with file_path.open('a') as f: kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 25931da141..c2b0924eb3 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -30,7 +30,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, mask_impl) + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) @@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, mask_impl) + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -113,12 +113,24 @@ if __name__ == "__main__": " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" ) + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. --optdim=32,64,128,256" + ) + args = parser.parse_args() api_list = args.direction.split(',') filter_list = args.filter.split(',') filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if len(api_list) > 1: + assert optdim_list == [-1] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)