mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Apply RoPE to q_tile
This commit is contained in:
@@ -1098,7 +1098,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
// optionally apply RoPE to the q_host_ref
|
||||
if(false && 0 < rotary_dim)
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
@@ -1107,6 +1107,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
#if 1
|
||||
HOST_DEBUG_STMTS {
|
||||
printf("\n");
|
||||
for(size_t row = 0; row < q_host_ref.get_length(1) && row < 8; ++row)
|
||||
{
|
||||
printf("[HOST] q_host_ref[%3zu] = ", row);
|
||||
for(size_t col = 0; col < q_host_ref.get_length(2); ++col)
|
||||
{
|
||||
if (0 < col && col % 8 == 0) {
|
||||
printf("|");
|
||||
}
|
||||
|
||||
printf("%11.7f",
|
||||
ck_tile::type_convert<float>(q_host_ref(0, row, col)));
|
||||
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
@@ -1114,8 +1134,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// copy Knew to the end of K
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
printf("\n");
|
||||
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
@@ -1125,6 +1143,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
|
||||
#if 0
|
||||
HOST_DEBUG_STMTS {
|
||||
printf("\n");
|
||||
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
|
||||
{
|
||||
printf("[HOST] real_knew_host[%3zu] = ", row);
|
||||
@@ -1157,6 +1176,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
#if 0
|
||||
HOST_DEBUG_STMTS {
|
||||
printf("\n");
|
||||
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
|
||||
{
|
||||
printf("[HOST] real_knew_host_ref[%3zu] = ", row);
|
||||
|
||||
@@ -518,56 +518,110 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
constexpr auto rotary_cos_sin_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
|
||||
const auto rotary_cos_dram_window = [&]() {
|
||||
|
||||
constexpr auto q_rotary_cos_sin_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD / 2>{});
|
||||
const auto q_rotary_cos_dram_window = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
const auto rotary_cos_dram = [&]() {
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
const auto rotary_cos_dram = [&]() {
|
||||
return pad_tensor_view(rotary_cos_dram_native,
|
||||
rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
q_rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(rotary_cos_sin_dram_window_lengths);
|
||||
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
const auto rotary_sin_dram_window = [&]() {
|
||||
const auto q_rotary_sin_dram_window = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
const auto rotary_sin_dram = [&]() {
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
const auto rotary_sin_dram = [&]() {
|
||||
return pad_tensor_view(rotary_sin_dram_native,
|
||||
rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
q_rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(rotary_cos_sin_dram_window_lengths);
|
||||
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto knew_rotary_cos_sin_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
|
||||
const auto knew_rotary_cos_dram_window = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
const auto rotary_cos_dram = [&]() {
|
||||
return pad_tensor_view(rotary_cos_dram_native,
|
||||
knew_rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
const auto knew_rotary_sin_dram_window = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
const auto rotary_sin_dram = [&]() {
|
||||
return pad_tensor_view(rotary_sin_dram_native,
|
||||
knew_rotary_cos_sin_dram_window_lengths,
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -618,8 +672,10 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
vnew_dram_window,
|
||||
rotary_cos_dram_window,
|
||||
rotary_sin_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
smem_ptr,
|
||||
kargs.rotary_dim);
|
||||
}
|
||||
@@ -630,8 +686,10 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
vnew_dram_window,
|
||||
rotary_cos_dram_window,
|
||||
rotary_sin_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,8 +84,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename QElementFunction,
|
||||
typename KnewElementFunction,
|
||||
typename VnewElementFunction,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -95,8 +97,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
const RotaryCosDramBlockWindow rotary_cos_dram_block_window,
|
||||
const RotarySinDramBlockWindow rotary_sin_dram_block_window,
|
||||
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
|
||||
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
|
||||
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
@@ -169,15 +173,15 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
rotary_cos_dram_block_window.get_window_lengths(),
|
||||
rotary_cos_dram_block_window.get_window_origin(),
|
||||
make_tile_window(knew_rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_rotary_cos_dram_block_window.get_window_lengths(),
|
||||
knew_rotary_cos_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
rotary_sin_dram_block_window.get_window_lengths(),
|
||||
rotary_sin_dram_block_window.get_window_origin(),
|
||||
make_tile_window(knew_rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_rotary_sin_dram_block_window.get_window_lengths(),
|
||||
knew_rotary_sin_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we will
|
||||
@@ -274,15 +278,85 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
auto q = load_tile(q_window);
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
}();
|
||||
print_tile(q_tile, 8);
|
||||
/// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0)
|
||||
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(q_rotary_cos_dram_block_window.get_bottom_tensor_view(),
|
||||
q_rotary_cos_dram_block_window.get_window_lengths(),
|
||||
q_rotary_cos_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(q_rotary_sin_dram_block_window.get_bottom_tensor_view(),
|
||||
q_rotary_sin_dram_block_window.get_window_lengths(),
|
||||
q_rotary_sin_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we will
|
||||
// use the distribution to enable/disable threads in order to override knew_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}
|
||||
// use the distribution to enable/disable threads in order to override q_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
q_tile.thread_buf_[idx] = type_convert<KDataType>(left * cos - right * sin);
|
||||
q_tile.thread_buf_[idx + 1] =
|
||||
type_convert<KDataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
}
|
||||
constexpr index_t KPerThread = 8 / sizeof(QDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto q_other_window = q_window;
|
||||
move_tile_window(q_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto q_other_tile = load_tile(q_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
q_tile.thread_buf_[idx] =
|
||||
type_convert<KDataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
print_tile(q_tile, 8);
|
||||
store_tile(q_dram_block_window, q_tile);
|
||||
}
|
||||
}
|
||||
@@ -292,16 +366,20 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename KnewDramBlockWindow,
|
||||
typename VDramBlockWindow,
|
||||
typename VnewDramBlockWindow,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
const KnewDramBlockWindow& knew_dram_block_window,
|
||||
VDramBlockWindow& v_dram_block_window,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window,
|
||||
const RotaryCosDramBlockWindow& rotary_cos_dram_block_window,
|
||||
const RotarySinDramBlockWindow& rotary_sin_dram_block_window,
|
||||
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
|
||||
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
|
||||
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
@@ -313,8 +391,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
v_dram_block_window,
|
||||
vnew_dram_block_window,
|
||||
identity{},
|
||||
rotary_cos_dram_block_window,
|
||||
rotary_sin_dram_block_window,
|
||||
q_rotary_cos_dram_block_window,
|
||||
q_rotary_sin_dram_block_window,
|
||||
knew_rotary_cos_dram_block_window,
|
||||
knew_rotary_sin_dram_block_window,
|
||||
smem_ptr,
|
||||
rotary_dim);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user