Only apply interleaved RoPE on Knew for now

This commit is contained in:
PoYen, Chen
2024-07-18 19:42:14 +00:00
parent 85bfed07fa
commit 23450526c0
3 changed files with 161 additions and 6 deletions

View File

@@ -542,6 +542,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rotary_cos_host, rotary_sin_host] =
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
HOST_DEBUG_STMTS
{
#if 0
printf("rotary_cos's shape: (%2zu, %2zu)\n",
rotary_cos_host.get_length(0),
rotary_cos_host.get_length(1));
for(size_t row = 0; row < rotary_cos_host.get_length(0); ++row)
{
printf("[HOST] rotary_cos[%3zu] = ", row);
for(size_t col = 0; col < rotary_cos_host.get_length(1); ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(rotary_cos_host(row, col)));
}
printf("\n");
}
#endif
#if 1
printf("rotary_sin's shape: (%2zu, %2zu)\n",
rotary_sin_host.get_length(0),
rotary_sin_host.get_length(1));
for(size_t row = 0; row < rotary_sin_host.get_length(0); ++row)
{
printf("[HOST] rotary_sin[%3zu] = ", row);
for(size_t col = 0; col < rotary_sin_host.get_length(1); ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(rotary_sin_host(row, col)));
}
printf("\n");
}
#endif
}
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
@@ -580,11 +612,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
std::fill(knew_host.begin(), knew_host.end(), static_cast<KDataType>(99.f));
std::fill(vnew_host.begin(), vnew_host.end(), static_cast<VDataType>(99.f));
}
else if(init_method == "nf")
{
@@ -1057,7 +1088,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
// optionally apply RoPE to the q_host_ref
if(0 < rotary_dim)
if(false && 0 < rotary_dim)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
@@ -1073,6 +1104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// copy Knew to the end of K
if(0 < seqlen_knew)
{
printf("\n");
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
@@ -1094,6 +1127,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
real_knew_host_ref = &knew_host_ref_ro.value();
}
HOST_DEBUG_STMTS {
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
{
printf("[HOST] real_knew_host_ref[%3zu] = ", row);
for(size_t col = 0; col < real_knew_host_ref->get_length(2); ++col)
{
printf("%11.7f",
ck_tile::type_convert<float>((*real_knew_host_ref)(0, row, col)));
}
printf("\n");
}
}
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
k_host_ref.ForEach([&](auto& self, auto i) {
if(knew_start <= i[1])

View File

@@ -104,6 +104,13 @@ struct BlockFmhaFwdAppendKVPipeline
index_t rotary_dim = 0,
bool is_rotary_interleaved = false) const
{
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
if(threadIdx.x == 0)
{
printf("\n");
}
(void)q_dram_block_window_tmp;
(void)q_element_func;
(void)k_dram_block_window_tmp;
@@ -132,7 +139,109 @@ struct BlockFmhaFwdAppendKVPipeline
Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = load_tile(knew_dram_window);
/// TODO: apply RoPE on knew_tile here
if constexpr(kApplyRoPE)
{
auto rotary_cos_window = make_tile_window(
rotary_cos_block_window_tmp.get_bottom_tensor_view(),
rotary_cos_block_window_tmp.get_window_lengths(),
rotary_cos_block_window_tmp.get_window_origin(),
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
auto rotary_sin_window = make_tile_window(
rotary_sin_block_window_tmp.get_bottom_tensor_view(),
rotary_sin_block_window_tmp.get_window_lengths(),
rotary_sin_block_window_tmp.get_window_origin(),
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
#if 0
auto rotary_tile = rotary_sin_tile;
constexpr auto spans = decltype(rotary_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx);
});
});
block_sync_lds();
DEVICE_DEBUG_STMTS
{
for(int row = 0; row < kTileSizeSk; ++row)
{
printf("[DEVICE] rotary_smem[%3d] = ", row);
for(int col = 0; col < rotary_dim / 2; ++col)
{
printf("%11.7f", type_convert<float>(ksmem[row * (kTileSizeD / 2) + col]));
}
printf("\n");
}
}
#endif
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock);
if(start_x + KPerThread <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
knew_tile.thread_buf_[idx] = left * cos - right * sin;
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
});
}
#if 0
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
{
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
knew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ksmem[row * kTileSizeD + col] = knew_tile(i_j_idx);
});
});
}
block_sync_lds();
DEVICE_DEBUG_STMTS
{
for(int row = 0; row < 7; ++row)
{
printf("[DEVICE] knew_tile[%3d] = ", row);
for(int col = 0; col < kTileSizeD; ++col)
{
printf("%11.7f", type_convert<float>(ksmem[row * kTileSizeD + col]));
}
printf("\n");
}
}
#endif
}
store_tile(k_dram_block_window_tmp, knew_tile);
auto vnew_dram_block_window =

View File

@@ -54,7 +54,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2);
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD);
}
template <typename Problem>