mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Move thread locating logics into policy
This commit is contained in:
@@ -184,17 +184,16 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override
|
||||
// knew_tile content
|
||||
auto [thread_start, thread_end] =
|
||||
Policy::template GetKnewThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size =
|
||||
decltype(knew_tile.thread_buf_)::size();
|
||||
@@ -217,14 +216,9 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
@@ -291,20 +285,17 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override q_tile
|
||||
// content
|
||||
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
|
||||
@@ -323,14 +314,9 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
auto q_other_window = q_window;
|
||||
move_tile_window(q_other_window,
|
||||
@@ -344,7 +330,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
|
||||
|
||||
@@ -57,6 +57,31 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
|
||||
{
|
||||
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
|
||||
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType);
|
||||
static_assert(Problem::kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
|
||||
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(typename Problem::QDataType);
|
||||
static_assert(Problem::kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
|
||||
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
@@ -91,6 +116,29 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
|
||||
{
|
||||
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
|
||||
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType);
|
||||
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(typename Problem::KDataType);
|
||||
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user