mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Only apply interleaved RoPE on Knew for now
This commit is contained in:
@@ -542,6 +542,38 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto [rotary_cos_host, rotary_sin_host] =
|
||||
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
|
||||
|
||||
HOST_DEBUG_STMTS
|
||||
{
|
||||
#if 0
|
||||
printf("rotary_cos's shape: (%2zu, %2zu)\n",
|
||||
rotary_cos_host.get_length(0),
|
||||
rotary_cos_host.get_length(1));
|
||||
for(size_t row = 0; row < rotary_cos_host.get_length(0); ++row)
|
||||
{
|
||||
printf("[HOST] rotary_cos[%3zu] = ", row);
|
||||
for(size_t col = 0; col < rotary_cos_host.get_length(1); ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(rotary_cos_host(row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
#if 1
|
||||
printf("rotary_sin's shape: (%2zu, %2zu)\n",
|
||||
rotary_sin_host.get_length(0),
|
||||
rotary_sin_host.get_length(1));
|
||||
for(size_t row = 0; row < rotary_sin_host.get_length(0); ++row)
|
||||
{
|
||||
printf("[HOST] rotary_sin[%3zu] = ", row);
|
||||
for(size_t col = 0; col < rotary_sin_host.get_length(1); ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(rotary_sin_host(row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
@@ -580,11 +612,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
|
||||
std::fill(knew_host.begin(), knew_host.end(), static_cast<KDataType>(99.f));
|
||||
std::fill(vnew_host.begin(), vnew_host.end(), static_cast<VDataType>(99.f));
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
@@ -1057,7 +1088,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
// optionally apply RoPE to the q_host_ref
|
||||
if(0 < rotary_dim)
|
||||
if(false && 0 < rotary_dim)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
@@ -1073,6 +1104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// copy Knew to the end of K
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
printf("\n");
|
||||
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
@@ -1094,6 +1127,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
real_knew_host_ref = &knew_host_ref_ro.value();
|
||||
}
|
||||
|
||||
HOST_DEBUG_STMTS {
|
||||
for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row)
|
||||
{
|
||||
printf("[HOST] real_knew_host_ref[%3zu] = ", row);
|
||||
for(size_t col = 0; col < real_knew_host_ref->get_length(2); ++col)
|
||||
{
|
||||
printf("%11.7f",
|
||||
ck_tile::type_convert<float>((*real_knew_host_ref)(0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
if(knew_start <= i[1])
|
||||
|
||||
@@ -104,6 +104,13 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
index_t rotary_dim = 0,
|
||||
bool is_rotary_interleaved = false) const
|
||||
{
|
||||
|
||||
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
(void)q_dram_block_window_tmp;
|
||||
(void)q_element_func;
|
||||
(void)k_dram_block_window_tmp;
|
||||
@@ -132,7 +139,109 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
auto knew_tile = load_tile(knew_dram_window);
|
||||
/// TODO: apply RoPE on knew_tile here
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
auto rotary_cos_window = make_tile_window(
|
||||
rotary_cos_block_window_tmp.get_bottom_tensor_view(),
|
||||
rotary_cos_block_window_tmp.get_window_lengths(),
|
||||
rotary_cos_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
|
||||
|
||||
auto rotary_sin_window = make_tile_window(
|
||||
rotary_sin_block_window_tmp.get_bottom_tensor_view(),
|
||||
rotary_sin_block_window_tmp.get_window_lengths(),
|
||||
rotary_sin_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
|
||||
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
#if 0
|
||||
auto rotary_tile = rotary_sin_tile;
|
||||
constexpr auto spans = decltype(rotary_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
for(int row = 0; row < kTileSizeSk; ++row)
|
||||
{
|
||||
printf("[DEVICE] rotary_smem[%3d] = ", row);
|
||||
for(int col = 0; col < rotary_dim / 2; ++col)
|
||||
{
|
||||
printf("%11.7f", type_convert<float>(ksmem[row * (kTileSizeD / 2) + col]));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock);
|
||||
if(start_x + KPerThread <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
knew_tile.thread_buf_[idx] = left * cos - right * sin;
|
||||
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
|
||||
});
|
||||
}
|
||||
#if 0
|
||||
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
|
||||
|
||||
{
|
||||
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
knew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
ksmem[row * kTileSizeD + col] = knew_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
for(int row = 0; row < 7; ++row)
|
||||
{
|
||||
printf("[DEVICE] knew_tile[%3d] = ", row);
|
||||
for(int col = 0; col < kTileSizeD; ++col)
|
||||
{
|
||||
printf("%11.7f", type_convert<float>(ksmem[row * kTileSizeD + col]));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
store_tile(k_dram_block_window_tmp, knew_tile);
|
||||
|
||||
auto vnew_dram_block_window =
|
||||
|
||||
@@ -54,7 +54,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2);
|
||||
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user