Apply RoPE to q_tile

This commit is contained in:
PoYen, Chen
2024-07-23 03:54:11 +00:00
parent e88253a2f4
commit 48c70720b5
3 changed files with 214 additions and 56 deletions

View File

@@ -1098,7 +1098,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
// optionally apply RoPE to the q_host_ref
if(false && 0 < rotary_dim)
if(0 < rotary_dim)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
@@ -1107,6 +1107,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#if 1
HOST_DEBUG_STMTS {
printf("\n");
for(size_t row = 0; row < q_host_ref.get_length(1) && row < 8; ++row)
{
printf("[HOST] q_host_ref[%3zu] = ", row);
for(size_t col = 0; col < q_host_ref.get_length(2); ++col)
{
if (0 < col && col % 8 == 0) {
printf("|");
}
printf("%11.7f",
ck_tile::type_convert<float>(q_host_ref(0, row, col)));
}
printf("\n");
}
}
#endif
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
@@ -1114,8 +1134,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// copy Knew to the end of K
if(0 < seqlen_knew)
{
printf("\n");
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
@@ -1125,6 +1143,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
#if 0
HOST_DEBUG_STMTS {
printf("\n");
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
{
printf("[HOST] real_knew_host[%3zu] = ", row);
@@ -1157,6 +1176,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
#if 0
HOST_DEBUG_STMTS {
printf("\n");
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
{
printf("[HOST] real_knew_host_ref[%3zu] = ", row);

View File

@@ -518,56 +518,110 @@ struct FmhaFwdAppendKVKernel
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
constexpr auto rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
const auto rotary_cos_dram_window = [&]() {
constexpr auto q_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD / 2>{});
const auto q_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram = [&]() {
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
}
else
{
return make_null_tile_window(rotary_cos_sin_dram_window_lengths);
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
const auto rotary_sin_dram_window = [&]() {
const auto q_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram = [&]() {
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, rotary_cos_sin_dram_window_lengths, {0, 0});
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0});
}
else
{
return make_null_tile_window(rotary_cos_sin_dram_window_lengths);
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
constexpr auto knew_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD / 2>{});
const auto knew_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
const auto knew_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
@@ -618,8 +672,10 @@ struct FmhaFwdAppendKVKernel
knew_dram_window,
v_dram_window,
vnew_dram_window,
rotary_cos_dram_window,
rotary_sin_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
smem_ptr,
kargs.rotary_dim);
}
@@ -630,8 +686,10 @@ struct FmhaFwdAppendKVKernel
knew_dram_window,
v_dram_window,
vnew_dram_window,
rotary_cos_dram_window,
rotary_sin_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
smem_ptr);
}
}

View File

@@ -84,8 +84,10 @@ struct BlockFmhaFwdAppendKVPipeline
typename QElementFunction,
typename KnewElementFunction,
typename VnewElementFunction,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -95,8 +97,10 @@ struct BlockFmhaFwdAppendKVPipeline
VDramBlockWindow& v_dram_block_window, // N1*K1 tile
const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile
const VnewElementFunction& vnew_element_func,
const RotaryCosDramBlockWindow rotary_cos_dram_block_window,
const RotarySinDramBlockWindow rotary_sin_dram_block_window,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
void* smem_ptr,
index_t rotary_dim = 0) const
{
@@ -169,15 +173,15 @@ struct BlockFmhaFwdAppendKVPipeline
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(),
rotary_cos_dram_block_window.get_window_lengths(),
rotary_cos_dram_block_window.get_window_origin(),
make_tile_window(knew_rotary_cos_dram_block_window.get_bottom_tensor_view(),
knew_rotary_cos_dram_block_window.get_window_lengths(),
knew_rotary_cos_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
auto rotary_sin_window =
make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(),
rotary_sin_dram_block_window.get_window_lengths(),
rotary_sin_dram_block_window.get_window_origin(),
make_tile_window(knew_rotary_sin_dram_block_window.get_bottom_tensor_view(),
knew_rotary_sin_dram_block_window.get_window_lengths(),
knew_rotary_sin_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
// We assume that each thread owns contiguous elements on head dimension. And we will
@@ -274,15 +278,85 @@ struct BlockFmhaFwdAppendKVPipeline
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
print_tile(q_tile, 8);
/// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0)
auto rotary_cos_window =
make_tile_window(q_rotary_cos_dram_block_window.get_bottom_tensor_view(),
q_rotary_cos_dram_block_window.get_window_lengths(),
q_rotary_cos_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
auto rotary_sin_window =
make_tile_window(q_rotary_sin_dram_block_window.get_bottom_tensor_view(),
q_rotary_sin_dram_block_window.get_window_lengths(),
q_rotary_sin_dram_block_window.get_window_origin(),
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
// We assume that each thread owns contiguous elements on head dimension. And we will
// use the distribution to enable/disable threads in order to override knew_tile content
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}
// use the distribution to enable/disable threads in order to override q_tile content
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t KPerThread = 16 / sizeof(QDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
if((start_x + KPerThread) <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<float>(q_tile.thread_buf_[idx]);
const auto right = type_convert<float>(q_tile.thread_buf_[idx + 1]);
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
q_tile.thread_buf_[idx] = type_convert<KDataType>(left * cos - right * sin);
q_tile.thread_buf_[idx + 1] =
type_convert<KDataType>(right * cos + left * sin);
});
}
}
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
{
}
constexpr index_t KPerThread = 8 / sizeof(QDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
if((start_x + KPerThread) <= rotary_dim)
{
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
auto q_other_window = q_window;
move_tile_window(q_other_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto q_other_tile = load_tile(q_other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<float>(q_tile.thread_buf_[idx]);
const auto other = type_convert<float>(q_other_tile.thread_buf_[idx]);
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx]);
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx]);
q_tile.thread_buf_[idx] =
type_convert<KDataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
print_tile(q_tile, 8);
store_tile(q_dram_block_window, q_tile);
}
}
@@ -292,16 +366,20 @@ struct BlockFmhaFwdAppendKVPipeline
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VnewDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
const VnewDramBlockWindow& vnew_dram_block_window,
const RotaryCosDramBlockWindow& rotary_cos_dram_block_window,
const RotarySinDramBlockWindow& rotary_sin_dram_block_window,
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
void* smem_ptr,
index_t rotary_dim = 0) const
{
@@ -313,8 +391,10 @@ struct BlockFmhaFwdAppendKVPipeline
v_dram_block_window,
vnew_dram_block_window,
identity{},
rotary_cos_dram_block_window,
rotary_sin_dram_block_window,
q_rotary_cos_dram_block_window,
q_rotary_sin_dram_block_window,
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
smem_ptr,
rotary_dim);
}