mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fix hd128 scratches and boost performance
This commit is contained in:
@@ -455,7 +455,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
|
||||
# "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
"kr_ktr_vr"],
|
||||
# '128' : [FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
# '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr"],
|
||||
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr"]
|
||||
@@ -481,7 +481,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if (((hdim == 64 or hdim == 128) and ("wg16" in dropout)) or ((hdim == 32 or hdim == 256) and ("wg32" in dropout))):
|
||||
if (((hdim == 64) and ("wg16" in dropout)) or ((hdim != 64) and ("wg32" in dropout))):
|
||||
continue
|
||||
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
|
||||
@@ -265,7 +265,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
const ck_tile::index_t kN0 = (hdim_q > 32 & hdim_q <= 128) ? 128 : 64;
|
||||
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
|
||||
const ck_tile::index_t nsplits =
|
||||
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user