diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index f43de4573a..ce8493663f 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask // index_t x_end = min(i_y + x, x_total); bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > (i_x + y); + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now return top_right_edge || bottom_left_edge; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index a013ee3d57..d867772a1f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -501,9 +501,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() { - using QGradDataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType); constexpr index_t kKPack = GetSmemKPackOGrad(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = [&]() {