mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Extract rotary embedding logic out
This commit is contained in:
@@ -34,4 +34,75 @@ struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
|
||||
static constexpr const char* name = "half";
|
||||
};
|
||||
|
||||
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
|
||||
struct BlockRotaryEmbedding
|
||||
{
|
||||
template <typename DistributedTensor,
|
||||
typename OtherDramBlockWindow,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
|
||||
OtherDramBlockWindow other_window,
|
||||
RotaryCosDramBlockWindow rotary_cos_window,
|
||||
RotarySinDramBlockWindow rotary_sin_window,
|
||||
index_t rotary_dim,
|
||||
index_t thread_end)
|
||||
{
|
||||
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
|
||||
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
|
||||
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto other_tile = load_tile(other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
tile.thread_buf_[idx] =
|
||||
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -182,69 +182,18 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override
|
||||
// will use the distribution to enable/disable threads in order to override partial
|
||||
// knew_tile content
|
||||
auto [thread_start, thread_end] =
|
||||
Policy::template GetKnewThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size =
|
||||
decltype(knew_tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
knew_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(left * cos - right * sin);
|
||||
knew_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto knew_other_tile = load_tile(knew_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
decltype(knew_tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto other =
|
||||
type_convert<float>(knew_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
knew_tile.thread_buf_[idx] = type_convert<KDataType>(
|
||||
curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
BlockRotaryEmbedding<RotaryEnum>::apply(knew_tile,
|
||||
knew_window,
|
||||
rotary_cos_window,
|
||||
rotary_sin_window,
|
||||
rotary_dim,
|
||||
thread_end);
|
||||
}
|
||||
print_tile(knew_tile, 2);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
@@ -281,65 +230,14 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override q_tile
|
||||
// content
|
||||
// will use the distribution to enable/disable threads in order to override partial
|
||||
// q_tile content
|
||||
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
BlockRotaryEmbedding<RotaryEnum>::apply(
|
||||
q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
|
||||
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
q_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(left * cos - right * sin);
|
||||
q_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
auto q_other_window = q_window;
|
||||
move_tile_window(q_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto q_other_tile = load_tile(q_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
q_tile.thread_buf_[idx] = type_convert<KDataType>(
|
||||
curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
// print_tile(q_tile, 8);
|
||||
store_tile(q_dram_block_window, q_tile);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user