diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index 85138b0bd..f31c8024b 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -157,7 +157,8 @@ struct CausalMask : NoMask { TileShape const& tile_shape, ProblemSize const& problem_size) { - return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } template