mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Skip code if # of block is more than needed
This commit is contained in:
@@ -101,6 +101,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
|
||||
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
|
||||
bool skip_q,
|
||||
bool skip_kv,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
@@ -158,206 +160,206 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
#endif
|
||||
};
|
||||
|
||||
auto knew_window =
|
||||
make_tile_window(knew_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_dram_block_window.get_window_lengths(),
|
||||
knew_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
auto knew_tile = [&]() {
|
||||
auto knew = load_tile(knew_window);
|
||||
return tile_elementwise_in(knew_element_func, knew);
|
||||
}();
|
||||
|
||||
// optionally apply rotary embedding to Knew
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
if(!skip_kv)
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(knew_rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_rotary_cos_dram_block_window.get_window_lengths(),
|
||||
knew_rotary_cos_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
auto knew_window = make_tile_window(
|
||||
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(knew_rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_rotary_sin_dram_block_window.get_window_lengths(),
|
||||
knew_rotary_sin_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we will
|
||||
// use the distribution to enable/disable threads in order to override knew_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
knew_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(left * cos - right * sin);
|
||||
knew_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto knew_other_tile = load_tile(knew_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(knew_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
knew_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// print_tile(knew_tile, 7);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
auto vnew_window =
|
||||
make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(),
|
||||
vnew_dram_block_window.get_window_lengths(),
|
||||
vnew_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeVnewDramTileDistribution<Problem>());
|
||||
|
||||
auto vnew_tile = [&]() {
|
||||
auto vnew = load_tile(vnew_window);
|
||||
return tile_elementwise_in(vnew_element_func, vnew);
|
||||
}();
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
|
||||
// optionally apply rotary embedding to Q
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = [&]() {
|
||||
auto q = load_tile(q_window);
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
auto knew_tile = [&]() {
|
||||
auto knew = load_tile(knew_window);
|
||||
return tile_elementwise_in(knew_element_func, knew);
|
||||
}();
|
||||
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(q_rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
q_rotary_cos_dram_block_window.get_window_lengths(),
|
||||
q_rotary_cos_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(q_rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
q_rotary_sin_dram_block_window.get_window_lengths(),
|
||||
q_rotary_sin_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we will
|
||||
// use the distribution to enable/disable threads in order to override q_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
// optionally apply rotary embedding to Knew
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(knew_rotary_cos_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(knew_rotary_sin_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override
|
||||
// knew_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
q_tile.thread_buf_[idx] = type_convert<KDataType>(left * cos - right * sin);
|
||||
q_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto q_other_window = q_window;
|
||||
move_tile_window(q_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto q_other_tile = load_tile(q_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size =
|
||||
decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
q_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
const auto cos =
|
||||
type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
knew_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(left * cos - right * sin);
|
||||
knew_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto knew_other_tile = load_tile(knew_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto other =
|
||||
type_convert<float>(knew_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
knew_tile.thread_buf_[idx] = type_convert<KDataType>(
|
||||
curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
print_tile(q_tile, 8);
|
||||
store_tile(q_dram_block_window, q_tile);
|
||||
print_tile(knew_tile, 2);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
auto vnew_window = make_tile_window(
|
||||
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
|
||||
|
||||
auto vnew_tile = [&]() {
|
||||
auto vnew = load_tile(vnew_window);
|
||||
return tile_elementwise_in(vnew_element_func, vnew);
|
||||
}();
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
}
|
||||
|
||||
if(!skip_q)
|
||||
{
|
||||
// optionally apply rotary embedding to Q
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto q_window = make_tile_window(
|
||||
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = [&]() {
|
||||
auto q = load_tile(q_window);
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
}();
|
||||
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(q_rotary_cos_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(q_rotary_sin_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override q_tile
|
||||
// content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
q_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(left * cos - right * sin);
|
||||
q_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto q_other_window = q_window;
|
||||
move_tile_window(q_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto q_other_tile = load_tile(q_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
q_tile.thread_buf_[idx] = type_convert<KDataType>(
|
||||
curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
// print_tile(q_tile, 8);
|
||||
store_tile(q_dram_block_window, q_tile);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,6 +382,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
|
||||
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
|
||||
bool skip_q,
|
||||
bool skip_kv,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
@@ -395,6 +399,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
q_rotary_sin_dram_block_window,
|
||||
knew_rotary_cos_dram_block_window,
|
||||
knew_rotary_sin_dram_block_window,
|
||||
skip_q,
|
||||
skip_kv,
|
||||
smem_ptr,
|
||||
rotary_dim);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user