mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Only apply interleaved RoPE on Knew for now
This commit is contained in:
@@ -542,6 +542,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto [rotary_cos_host, rotary_sin_host] =
|
||||
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
|
||||
|
||||
HOST_DEBUG_STMTS
|
||||
{
|
||||
#if 0
|
||||
printf("rotary_cos's shape: (%2zu, %2zu)\n",
|
||||
rotary_cos_host.get_length(0),
|
||||
rotary_cos_host.get_length(1));
|
||||
for(size_t row = 0; row < rotary_cos_host.get_length(0); ++row)
|
||||
{
|
||||
printf("[HOST] rotary_cos[%3zu] = ", row);
|
||||
for(size_t col = 0; col < rotary_cos_host.get_length(1); ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(rotary_cos_host(row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
#if 1
|
||||
printf("rotary_sin's shape: (%2zu, %2zu)\n",
|
||||
rotary_sin_host.get_length(0),
|
||||
rotary_sin_host.get_length(1));
|
||||
for(size_t row = 0; row < rotary_sin_host.get_length(0); ++row)
|
||||
{
|
||||
printf("[HOST] rotary_sin[%3zu] = ", row);
|
||||
for(size_t col = 0; col < rotary_sin_host.get_length(1); ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(rotary_sin_host(row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
@@ -580,11 +612,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
|
||||
std::fill(knew_host.begin(), knew_host.end(), static_cast<KDataType>(99.f));
|
||||
std::fill(vnew_host.begin(), vnew_host.end(), static_cast<VDataType>(99.f));
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
@@ -1057,7 +1088,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
// optionally apply RoPE to the q_host_ref
|
||||
if(0 < rotary_dim)
|
||||
if(false && 0 < rotary_dim)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
@@ -1073,6 +1104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// copy Knew to the end of K
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
printf("\n");
|
||||
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
@@ -1094,6 +1127,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
real_knew_host_ref = &knew_host_ref_ro.value();
|
||||
}
|
||||
|
||||
HOST_DEBUG_STMTS {
|
||||
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
|
||||
{
|
||||
printf("[HOST] real_knew_host_ref[%3zu] = ", row);
|
||||
for(size_t col = 0; col < real_knew_host_ref->get_length(2); ++col)
|
||||
{
|
||||
printf("%11.7f",
|
||||
ck_tile::type_convert<float>((*real_knew_host_ref)(0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
if(knew_start <= i[1])
|
||||
|
||||
Reference in New Issue
Block a user