Only apply interleaved RoPE on Knew for now

This commit is contained in:
PoYen, Chen
2024-07-18 19:42:14 +00:00
parent 85bfed07fa
commit 23450526c0
3 changed files with 161 additions and 6 deletions

View File

@@ -542,6 +542,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rotary_cos_host, rotary_sin_host] =
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
HOST_DEBUG_STMTS
{
#if 0
printf("rotary_cos's shape: (%2zu, %2zu)\n",
rotary_cos_host.get_length(0),
rotary_cos_host.get_length(1));
for(size_t row = 0; row < rotary_cos_host.get_length(0); ++row)
{
printf("[HOST] rotary_cos[%3zu] = ", row);
for(size_t col = 0; col < rotary_cos_host.get_length(1); ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(rotary_cos_host(row, col)));
}
printf("\n");
}
#endif
#if 1
printf("rotary_sin's shape: (%2zu, %2zu)\n",
rotary_sin_host.get_length(0),
rotary_sin_host.get_length(1));
for(size_t row = 0; row < rotary_sin_host.get_length(0); ++row)
{
printf("[HOST] rotary_sin[%3zu] = ", row);
for(size_t col = 0; col < rotary_sin_host.get_length(1); ++col)
{
printf("%11.7f", ck_tile::type_convert<float>(rotary_sin_host(row, col)));
}
printf("\n");
}
#endif
}
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
@@ -580,11 +612,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
std::fill(knew_host.begin(), knew_host.end(), static_cast<KDataType>(99.f));
std::fill(vnew_host.begin(), vnew_host.end(), static_cast<VDataType>(99.f));
}
else if(init_method == "nf")
{
@@ -1057,7 +1088,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
// optionally apply RoPE to the q_host_ref
if(0 < rotary_dim)
if(false && 0 < rotary_dim)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
@@ -1073,6 +1104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// copy Knew to the end of K
if(0 < seqlen_knew)
{
printf("\n");
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
@@ -1094,6 +1127,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
real_knew_host_ref = &knew_host_ref_ro.value();
}
HOST_DEBUG_STMTS {
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
{
printf("[HOST] real_knew_host_ref[%3zu] = ", row);
for(size_t col = 0; col < real_knew_host_ref->get_length(2); ++col)
{
printf("%11.7f",
ck_tile::type_convert<float>((*real_knew_host_ref)(0, row, col)));
}
printf("\n");
}
}
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
k_host_ref.ForEach([&](auto& self, auto i) {
if(knew_start <= i[1])