mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add receipt 10~12 for codegen of aiter integration (#1877)
* Add receipt for aiter integration * update receipt * Add hdim 96 instances * Revert "Add hdim 96 instances" This reverts commitf339449f54. [ROCm/composable_kernel commit:f0d49d14fc]
This commit is contained in:
@@ -499,14 +499,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 3:
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 4:
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
@@ -514,6 +514,24 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 10:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 11:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -494,13 +494,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 4:
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 10:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 11:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -326,7 +326,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
if receipt in (2, 12):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
|
||||
@@ -268,7 +268,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
|
||||
@@ -712,6 +712,22 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 11:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 12:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
|
||||
@@ -104,7 +104,11 @@ if __name__ == "__main__":
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration"
|
||||
" 4: Only generate instance for PyTorch integration\n" + \
|
||||
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \
|
||||
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \
|
||||
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
|
||||
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user