mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Fix sliding window mask: use window_generic when left >= 0
mask_info::decode('b:left,right,sink') always created mask_bottom_right
(IsLocal=false) which ignores the left window boundary. For sliding
window attention (left >= 0), use window_generic (IsLocal=true) so the
kernel respects the left boundary.
Fixes: CK split-KV producing identical results with and without sliding
window. Now 724/724 shapes pass correctness vs Triton.
Made-with: Cursor
This commit is contained in:
@@ -122,7 +122,7 @@ struct mask_info
|
|||||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||||
sink = 0;
|
sink = 0;
|
||||||
}
|
}
|
||||||
tmp.type = mask_enum::mask_bottom_right;
|
tmp.type = (v0 >= 0) ? mask_enum::window_generic : mask_enum::mask_bottom_right;
|
||||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||||
v0, v1, sink, y_total, x_total, false);
|
v0, v1, sink, y_total, x_total, false);
|
||||||
tmp.y = r.at(ck_tile::number<0>{});
|
tmp.y = r.at(ck_tile::number<0>{});
|
||||||
|
|||||||
Reference in New Issue
Block a user