diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index fec294ea5e..ab3fea23ed 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -191,8 +191,8 @@ struct BlockFmhaFwdAppendKVPipeline constexpr index_t KPerThread = 16 / sizeof(KDataType); static_assert(kTileSizeD % KPerThread == 0); constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock); - if(start_x + KPerThread <= rotary_dim) + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + if((start_x + KPerThread) <= rotary_dim) { constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); static_assert(thread_buffer_size % KPerThread == 0); @@ -207,7 +207,7 @@ struct BlockFmhaFwdAppendKVPipeline knew_tile.thread_buf_[idx + 1] = right * cos + left * sin; }); } -#if 0 +#if defined(ENABLE_DEVICE_DEBUG_STMTS) DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); } {