mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Fix wrong rotary_cos/rotary_sin memory size for Q
This commit is contained in:
@@ -1106,9 +1106,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
|
||||
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
|
||||
|
||||
#if 0
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
#endif
|
||||
}
|
||||
#if 0
|
||||
HOST_DEBUG_STMTS {
|
||||
|
||||
@@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_cos_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
@@ -519,7 +519,7 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto rotary_sin_dram_native =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
|
||||
make_tuple(kargs.rotary_dim / 2, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -108,10 +108,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
|
||||
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
auto print_tile = [&](const auto& tile, index_t num_display_rows = -1) {
|
||||
@@ -140,6 +136,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("\n");
|
||||
for(int row = 0;
|
||||
row < (0 < num_display_rows ? std::min(num_display_rows, num_rows) : num_rows);
|
||||
++row)
|
||||
@@ -195,7 +192,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
rotary_dim,
|
||||
thread_end);
|
||||
}
|
||||
print_tile(knew_tile, 2);
|
||||
// print_tile(knew_tile, 2);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
auto vnew_window = make_tile_window(
|
||||
|
||||
Reference in New Issue
Block a user