diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 4c23250d05..17f9c64843 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -412,13 +412,19 @@ class FmhaBwdDQDKDVKernel: pn = pad_name() n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_bias != 'no' : + n += f'_{self.F_bias}' + else: + n += '_nbias' if self.F_dbias == 't' : n += '_dbias' if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + if self.F_dropout != 'no' : + n += f'_{self.F_dropout}' + else: + n += '_ndropout' if self.F_deterministic == 't' : n += '_deterministic' return n @@ -489,7 +495,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # Flash attention integration @@ -517,23 +523,19 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if not cond: continue # Aiter (mha_bwd) integration - elif receipt == 10: + elif receipt == 300: cond = dtype in ['fp16', 'bf16'] cond &= mode == "batch" - cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "t" if not cond: continue # Aiter (mha_varlen_bwd) integration - elif receipt == 11: + elif receipt == 400: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" - cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "t" if not cond: continue api_pool.register_dq_dk_dv_traits(k.api_trait()) @@ -638,7 +640,7 @@ class FmhaBwdOGradDotOKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: +def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -657,6 +659,21 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_spad=spad, F_dvpad=dvpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim)) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -773,7 +790,7 @@ class FmhaBwdConvertQGradKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: +def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -792,6 +809,21 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: continue k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -808,27 +840,33 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_bwd_dot_do_o_blobs() +def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs() + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b72627ed5d..79ace6d2c3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -233,13 +233,22 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_bias != 'no' : + n += f'_{self.F_bias}' + else: + n += '_nbias' if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : n += '_lse' - if self.F_dropout == 't' : n += '_dropout' + if self.F_lse == 't' : + n += '_lse' + else: + n += '_nlse' + if self.F_dropout == 't' : + n += '_dropout' + else: + n += '_ndropout' if self.F_squant == 't' : n += '_squant' return n @@ -484,7 +493,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # 2 - Flash attention integration @@ -504,20 +513,18 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if not cond: continue # Aiter(mha_fwd) integration - elif receipt == 10: + elif receipt == 100: cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" + cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue # Aiter(mha_varlen_fwd) integration - elif receipt == 11: + elif receipt == 200: cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" + cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue @@ -532,13 +539,13 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: with file_path.open('a') as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f8a89448ba..16048e3fb6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -323,12 +323,11 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # 2 - Flash attention integration - # 12 - Aiter(mha_fwd_kvcache) integration - if receipt in (2, 12): + if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c0ca666b11..b4eea36e86 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -397,14 +397,23 @@ class FmhaFwdSplitKVPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_bias != 'no' : + n += f'_{self.F_bias}' + else: + n += '_nbias' if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : n += '_lse' + if self.F_lse == 't' : + n += '_lse' + else: + n += '_nlse' if self.F_squant == 't' : n += '_squant' - if self.F_pagedkv == 't' : n += '_pagedkv' + if self.F_pagedkv == 't' : + n += '_pagedkv' + else: + n += '_npagedkv' return n @dataclass @@ -702,7 +711,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue # Flash attention integration @@ -714,20 +723,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if not cond: continue # Aiter(mha_varlen_fwd) integration - elif receipt == 11: + elif receipt == 200: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_fwd_kvcache) integration - elif receipt == 12: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue @@ -780,9 +779,15 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis F_mode=mode, F_tile=tile, F_pipeline=pipeline) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Aiter(mha_varlen_fwd) integration + if receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue gen.append(k) return gen @@ -794,21 +799,27 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) +def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0c2cef1ce7..0d35db14d4 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -30,7 +30,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -38,19 +38,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : output_dir.mkdir(parents=True, exist_ok=True) - for api in api_list: + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] handler(output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) # create an empty file / drop its contents if it exists open(file_path, "w").close() - for api in api_list: + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] handler(file_path, kernel_filter, receipt, mask_impl) @@ -84,6 +84,7 @@ if __name__ == "__main__": parser.add_argument( "-f", "--filter", + default='', required=False, help="filter out kernels that need to generate, using fnmatch module" ) @@ -105,15 +106,19 @@ if __name__ == "__main__": " 1: generate more instance to cover all hdim\n" + \ " 2: Only generate instance for Flash attention integration\n" + \ " 4: Only generate instance for PyTorch integration\n" + \ - " 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \ - " 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \ - " 12: Only generate instance for Aiter(mha_fwd_kvcache) integration" - + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" + ) args = parser.parse_args() api_list = args.direction.split(',') + filter_list = args.filter.split(',') + filter_list.extend([''] * (len(api_list) - len(filter_list))) + if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask)