Add code blocks for q_tile

This commit is contained in:
PoYen, Chen
2024-07-23 03:28:40 +00:00
parent 1dbed18555
commit e88253a2f4

View File

@@ -165,6 +165,7 @@ struct BlockFmhaFwdAppendKVPipeline
return tile_elementwise_in(knew_element_func, knew);
}();
// optionally apply rotary embedding to Knew
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
@@ -218,12 +219,11 @@ struct BlockFmhaFwdAppendKVPipeline
if((start_x + KPerThread) <= rotary_dim)
{
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
auto knew_other_window = knew_window;
move_tile_window(knew_other_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto knew_other_tile = load_tile(knew_other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
@@ -247,7 +247,7 @@ struct BlockFmhaFwdAppendKVPipeline
}
}
}
print_tile(knew_tile, 7);
// print_tile(knew_tile, 7);
store_tile(k_dram_block_window, knew_tile);
auto vnew_window =
@@ -262,6 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline
}();
store_tile(v_dram_block_window, vnew_tile);
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
@@ -273,7 +274,8 @@ struct BlockFmhaFwdAppendKVPipeline
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
print_tile(q_tile, 8);
/// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0)
// We assume that each thread owns contiguous elements along the head dimension, and we
// use the distribution to enable/disable threads in order to override the knew_tile content
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}