diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 03e1537c5d..bc0cf1e25b 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -122,7 +122,7 @@ struct mask_info v1 = atoi(v.substr(found_1 + 1).c_str()); sink = 0; } - tmp.type = mask_enum::mask_bottom_right; + tmp.type = (v0 >= 0) ? mask_enum::window_generic : mask_enum::mask_bottom_right; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( v0, v1, sink, y_total, x_total, false); tmp.y = r.at(ck_tile::number<0>{});