mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Remove using partitioner for all fmha kernels (#1778)
* Remove using tile partitioner for fmha_fwd_kernel * Remove using tile partitioner for fmha_fwd_splitkv and splitkv-combine kernels * Remove using tile partitioner for fmha_fwd_appendkv kernel * Unify the format of GetTileIndex
This commit is contained in:
@@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = {
|
||||
256: 256
|
||||
}
|
||||
|
||||
TILE_PARTITIONER_MAP = {
|
||||
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
|
||||
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
@@ -90,9 +85,7 @@ using fmha_epilogue_{F_idx} =
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
@@ -329,12 +322,6 @@ class FmhaFwdKernel:
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
def get_tp(self) -> str:
|
||||
if self.F_mode == 'group':
|
||||
return 'hbs'
|
||||
else:
|
||||
return 'shb'
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
@@ -374,13 +361,12 @@ class FmhaFwdKernel:
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
|
||||
return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
|
||||
@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
|
||||
fmha_pipeline_{F_idx}>;
|
||||
using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
|
||||
@@ -96,9 +96,7 @@ using fmha_epilogue =
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
@@ -176,11 +174,7 @@ using fmha_epilogue =
|
||||
false, false>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<
|
||||
ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
|
||||
fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
|
||||
Reference in New Issue
Block a user