mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Handle cross-page-block write
This commit is contained in:
@@ -403,6 +403,16 @@ struct FmhaFwdAppendKVKernel
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("[DEVICE] block_indics: ");
|
||||
for(index_t i_block = 0; i_block < num_blocks; ++i_block)
|
||||
{
|
||||
printf("(%d, %d) ", i_block, block_indices[i_block]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v;
|
||||
@@ -734,25 +744,45 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
auto [i_block1, v_dram_window_tmp] = v_tile_navigator.make_tile_window(
|
||||
v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
|
||||
DEVICE_DEBUG_STMTS
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
printf("[DEVICE] i_block1: %d\n", i_block1);
|
||||
auto local_origin = v_dram_window_tmp.get_window_origin();
|
||||
printf("[DEVICE] origin: (%d, %d)\n",
|
||||
local_origin.at(number<0>{}),
|
||||
local_origin.at(number<1>{}));
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("[DEVICE] i_block1: %d\n", i_block1);
|
||||
auto local_origin = v_dram_window_tmp.get_window_origin();
|
||||
printf("[DEVICE] origin: (%d, %d)\n",
|
||||
local_origin.at(number<0>{}),
|
||||
local_origin.at(number<1>{}));
|
||||
|
||||
printf("[DEVICE] psychical block_ptr 0: %p\n",
|
||||
static_cast<void*>(v_tile_navigator.physical_blocks +
|
||||
0 * v_tile_navigator.block_stride));
|
||||
printf("[DEVICE] psychical block_ptr 1: %p\n",
|
||||
static_cast<void*>(v_tile_navigator.physical_blocks +
|
||||
1 * v_tile_navigator.block_stride));
|
||||
|
||||
printf("[DEVICE] tile window data ptr: %p\n",
|
||||
static_cast<void*>(v_dram_window_tmp.get_bottom_tensor_view().buf_.p_data_));
|
||||
}
|
||||
}
|
||||
auto vnew_dram_window =
|
||||
make_tile_window(vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, i_n0});
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("[DEVICE] skip_transform_q: %d, skip_appendkv: %d\n",
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
}
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
i_block0,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
i_block1,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
@@ -768,8 +798,10 @@ struct FmhaFwdAppendKVKernel
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
i_block0,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
i_block1,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
|
||||
@@ -89,10 +89,12 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
index_t i_block0,
|
||||
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
|
||||
const KnewElementFunction& knew_element_func,
|
||||
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
|
||||
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
|
||||
index_t i_block1,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
|
||||
@@ -148,8 +150,14 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
/// TODO: handle cross-page-block write
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
// write tile to another block if nesscary
|
||||
if(k_tile_navigator.is_closs_block(k_dram_block_window))
|
||||
{
|
||||
k_tile_navigator.move_to_block(i_block0, k_dram_block_window, i_block0 + 1);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -167,8 +175,14 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
/// TODO: handle cross-page-block write
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
|
||||
// write tile to another block if nesscary
|
||||
if(v_tile_navigator.is_closs_block(v_dram_block_window))
|
||||
{
|
||||
v_tile_navigator.move_to_block(i_block1, v_dram_block_window, i_block1 + 1);
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -229,8 +243,10 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
index_t i_block0,
|
||||
const KnewDramBlockWindow& knew_dram_block_window,
|
||||
VDramBlockWindow& v_dram_block_window,
|
||||
index_t i_block1,
|
||||
const VnewDramBlockWindow& vnew_dram_block_window,
|
||||
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
|
||||
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
|
||||
@@ -245,9 +261,11 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
return operator()(q_dram_block_window,
|
||||
identity{},
|
||||
k_dram_block_window,
|
||||
i_block0,
|
||||
knew_dram_block_window,
|
||||
identity{},
|
||||
v_dram_block_window,
|
||||
i_block1,
|
||||
vnew_dram_block_window,
|
||||
identity{},
|
||||
q_rotary_cos_dram_block_window,
|
||||
|
||||
Reference in New Issue
Block a user