diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index d32d5a321d..5659162c97 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -872,7 +872,7 @@ struct FmhaBwdDQDKDVKernel } // skip if logical lengths are zero - if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0) + if(kargs.seqlen_q == 0 && kargs.seqlen_k == 0) { return; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index e4332df930..d12310add3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } KDataType* k_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 03ee1486da..79bf963cf7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } KDataType* k_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 7f893a93ba..966e2ddff4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -247,15 +247,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } auto k_lds = make_tensor_view(