From e88253a2f47771246a5494104f075284dadf6d40 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Jul 2024 03:28:40 +0000 Subject: [PATCH] Add code blocks for q_tile --- .../fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index d200d457b3..416f321487 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -165,6 +165,7 @@ struct BlockFmhaFwdAppendKVPipeline return tile_elementwise_in(knew_element_func, knew); }(); + // optionally apply rotary embedding to Knew if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { auto rotary_cos_window = @@ -218,12 +219,11 @@ struct BlockFmhaFwdAppendKVPipeline if((start_x + KPerThread) <= rotary_dim) { - bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); auto knew_other_window = knew_window; move_tile_window(knew_other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto knew_other_tile = load_tile(knew_other_window); move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); @@ -247,7 +247,7 @@ struct BlockFmhaFwdAppendKVPipeline } } } - print_tile(knew_tile, 7); + // print_tile(knew_tile, 7); store_tile(k_dram_block_window, knew_tile); auto vnew_window = @@ -262,6 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline }(); store_tile(v_dram_block_window, vnew_tile); + // optionally apply rotary embedding to Q if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(), @@ -273,7 +274,8 @@ struct BlockFmhaFwdAppendKVPipeline auto q = load_tile(q_window); return tile_elementwise_in(q_element_func, q); }(); - + print_tile(q_tile, 8); + /// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0) // We assume that each thread owns contiguous elements on head dimention. And we will // use the distribution to enable/disable threads in order to override knew_tile content if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}