Fix wrong per-thread starting offset (scale the thread index by KPerThread)

This commit is contained in:
PoYen, Chen
2024-07-18 20:02:06 +00:00
parent 23450526c0
commit 27b5141706

View File

@@ -191,8 +191,8 @@ struct BlockFmhaFwdAppendKVPipeline
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock);
if(start_x + KPerThread <= rotary_dim)
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
if((start_x + KPerThread) <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
@@ -207,7 +207,7 @@ struct BlockFmhaFwdAppendKVPipeline
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
});
}
#if 0
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
{