diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index d2566da653..8346d58a64 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -455,7 +455,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict # "kr_ktr_vr"], '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1), "kr_ktr_vr"], - # '128' : [FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 32, 32, 16, 32, 32, 16, 1), + # '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # "kr_ktr_vr"], # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # "kr_ktr_vr"] @@ -481,7 +481,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue if ((bias == "no" or bias == "alibi") and dbias == "t"): continue - if (((hdim == 64 or hdim == 128) and ("wg16" in dropout)) or ((hdim == 32 or hdim == 256) and ("wg32" in dropout))): + if (((hdim == 64) and ("wg16" in dropout)) or ((hdim != 64) and ("wg32" in dropout))): continue k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9923cce5f5..49681015a9 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -265,7 +265,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); - const ck_tile::index_t kN0 = (hdim_q > 32 & hdim_q <= 128) ? 128 : 64; + const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64; const ck_tile::index_t nsplits = deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;