mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK_TILE][FMHA][Feature] Add support for large hdim
* root cause: fhma_bwd not support if hdim > 256 due to the use of LDS goes beyond the hardware limitations. * solution: 1. split dqdkdv kernel into 2 kernels. * 1) QGrad * 2) KGrad & VGrad * 2. reuse LDS memory. * 1). K and K^T use same LDS memory in dq kernel * 2). OGrad and OGrad^T use same LDS memory in dq kernel * 3. to avoid or reduce the number of VGPR spills, the calculation order has been readjusted, and prefetch has been disabled.
This commit is contained in:
@@ -35,7 +35,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd
|
||||
--receipt 3
|
||||
--optdim 32,64,128,256
|
||||
--optdim 32,64,128,256,384,512
|
||||
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ if(FMHA_FWD_FAST_EXP2)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero -fbracket-depth=512)
|
||||
|
||||
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
|
||||
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
|
||||
@@ -12,23 +12,23 @@ FWD_DTYPE_MAP = {
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
"fp16": "FmhaBwdFp16",
|
||||
"bf16": "FmhaBwdBf16"
|
||||
# "bf16": "FmhaBwdBf16"
|
||||
}
|
||||
|
||||
MASK_IMPL = {
|
||||
"generic" : "ck_tile::GenericAttentionMask",
|
||||
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
|
||||
# "simplified" : "ck_tile::SimplifiedGenericAttentionMask"
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
# "s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no" : "FmhaMasks::NoMask",
|
||||
"causal" : "FmhaMasks::CausalMask",
|
||||
"generic" : "FmhaMasks::GenericMask"
|
||||
# "causal" : "FmhaMasks::CausalMask",
|
||||
# "generic" : "FmhaMasks::GenericMask"
|
||||
}
|
||||
|
||||
def get_mask_map(mask : str):
|
||||
@@ -62,8 +62,8 @@ def get_mask_check_map(mask : str):
|
||||
|
||||
BIAS_MAP = {
|
||||
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
|
||||
# "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
# "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
|
||||
}
|
||||
|
||||
# TODO: this is ugly
|
||||
@@ -75,10 +75,10 @@ BIAS_CHECK_MAP = {
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
# "dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
# "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
# "dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
# "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
@@ -103,7 +103,7 @@ ROPE_CHECK_MAP = {
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
# "group" : "true"
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
|
||||
@@ -351,15 +351,17 @@ class FmhaBwdDQDKDVKernel:
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'512' : [FmhaBwdDQDKDVTileSize( 16, 64, 512, 16, 512, 16, 32, 512, 512, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
}
|
||||
else:
|
||||
@@ -748,6 +750,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
|
||||
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
|
||||
|
||||
Reference in New Issue
Block a user