From 9d165a3b8ef446a7ff3db198413f82bcb83f46fe Mon Sep 17 00:00:00 2001
From: Taebum Kim
Date: Sat, 31 May 2025 11:51:18 +0900
Subject: [PATCH] Handle get_masked_trip_count for small length in fmha example (#2292)

* handle get_masked_trip_count for small length

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar

* Update examples/77_blackwell_fmha/collective/fmha_fusion.hpp

Co-authored-by: Vijay Thakkar

---------

Co-authored-by: Vijay Thakkar
---
 examples/77_blackwell_fmha/collective/fmha_fusion.hpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp
index 85138b0bd..f31c8024b 100644
--- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp
+++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp
@@ -157,7 +157,8 @@ struct CausalMask : NoMask {
       TileShape const& tile_shape,
       ProblemSize const& problem_size) {
 
-    return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
+    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
+    return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
   }
 
   template