Fix wrong rotary_cos/rotary_sin memory size for Q

This commit is contained in:
PoYen, Chen
2024-07-23 16:22:25 +00:00
parent 85bac93951
commit eb4ea3ac2a
3 changed files with 4 additions and 9 deletions

View File

@@ -1106,9 +1106,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
#if 0
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
#endif
}
#if 0
HOST_DEBUG_STMTS {

View File

@@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
@@ -519,7 +519,7 @@ struct FmhaFwdAppendKVKernel
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr),
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});

View File

@@ -108,10 +108,6 @@ struct BlockFmhaFwdAppendKVPipeline
{
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
if(threadIdx.x == 0)
{
printf("\n");
}
#endif
auto print_tile = [&](const auto& tile, index_t num_display_rows = -1) {
@@ -140,6 +136,7 @@ struct BlockFmhaFwdAppendKVPipeline
DEVICE_DEBUG_STMTS
{
printf("\n");
for(int row = 0;
row < (0 < num_display_rows ? std::min(num_display_rows, num_rows) : num_rows);
++row)
@@ -195,7 +192,7 @@ struct BlockFmhaFwdAppendKVPipeline
rotary_dim,
thread_end);
}
print_tile(knew_tile, 2);
// print_tile(knew_tile, 2);
store_tile(k_dram_block_window, knew_tile);
auto vnew_window = make_tile_window(