From d12c6417a0b353c541df4025ce9be683ca8ef41e Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Mon, 17 Feb 2025 14:29:25 +0800 Subject: [PATCH] Ck tile/paged attention workaround (#1894) * Correction in GetRangeAlongX() * Work-around to solve the failures in test_paged_attention_ck in xformers [ROCm/composable_kernel commit: a3757a5f9c40c1c8ff23e54c5b99c5e059ed1c39] --- include/ck_tile/ops/fmha/block/block_masking.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1569c93565..726543b97a 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -310,7 +310,7 @@ struct SimplifiedGenericAttentionMask const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); const index_t split_start = x_per_split * i_split; - const index_t split_end = split_start + x_per_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), ck_tile::min(origin_end, split_end)); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 92dc2bac3f..14d0596287 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -742,7 +742,7 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else {