diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 4dffdf836a..30b524d606 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -653,7 +653,7 @@ class FmhaBwdApiTrait: return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl) - + @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy @@ -709,7 +709,7 @@ class FmhaBwdApiPool: per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) -def get_bwd_blobs(filter_list: str, receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: if filter_list == '': filter_list = '*@*@*' filter_list = filter_list.split('@') @@ -743,13 +743,16 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl) -> Tuple[FmhaBwdApiPool, if (dpad == "t" or dvpad == "t"): ppl = d[hdim_str][2] t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) - + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): continue if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: @@ -803,9 +806,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl) -> Tuple[FmhaBwdApiPool, return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] # TODO - - api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl) + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) for k in kernels_dot_do_o: (output_dir / k.filename).write_text(k.template) @@ -816,10 +817,8 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] # TODO - _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( - filter_list, receipt, mask_impl + filter_list, receipt, mask_impl, optdim_list ) with file_path.open("a") as f: for k in kernels_dot_do_o: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 78cec40aa8..730641a6b0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -27,6 +27,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + 192: 192, 256: 256 } @@ -504,11 +505,11 @@ class KernelComponentFactory: return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } elif dtype == 'fp8' or dtype == 'bf8': diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 517e84f380..2e5bc2bd3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -326,6 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -334,7 +337,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16, bf16'] + cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue @@ -350,16 +353,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index edc1532a05..5b35e7f0bd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -637,9 +637,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,9 +656,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -670,7 +670,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -746,6 +746,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -783,7 +786,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -830,6 +833,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Aiter(mha_varlen_fwd) integration if receipt == 200: cond = dtype in ['fp16', 'bf16'] @@ -855,12 +861,11 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) @@ -868,13 +873,12 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index c611618824..0317330511 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -126,9 +126,6 @@ if __name__ == "__main__": filter_list.extend([''] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(',')] - if len(api_list) > 1: - assert optdim_list == [-1] - if args.list_blobs is not None: list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: