mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add code blocks for q_tile
This commit is contained in:
@@ -165,6 +165,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
return tile_elementwise_in(knew_element_func, knew);
|
||||
}();
|
||||
|
||||
// optionally apply rotary embedding to Knew
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
@@ -218,12 +219,11 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto knew_other_window = knew_window;
|
||||
move_tile_window(knew_other_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
|
||||
auto knew_other_tile = load_tile(knew_other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
@@ -247,7 +247,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
}
|
||||
}
|
||||
print_tile(knew_tile, 7);
|
||||
// print_tile(knew_tile, 7);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
auto vnew_window =
|
||||
@@ -262,6 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}();
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
|
||||
// optionally apply rotary embedding to Q
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
|
||||
@@ -273,7 +274,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
auto q = load_tile(q_window);
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
}();
|
||||
|
||||
print_tile(q_tile, 8);
|
||||
/// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0)
|
||||
// We assume that each thread owns contiguous elements on head dimension. And we will
|
||||
// use the distribution to enable/disable threads in order to override knew_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}
|
||||
|
||||
Reference in New Issue
Block a user