From 0e5cb6f913c71e614d0ad938903cdc46d6169dd1 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Jul 2024 06:53:24 +0000 Subject: [PATCH] Skip code if the number of blocks is more than needed --- .../block_fmha_fwd_appendkv_pipeline.hpp | 376 +++++++++--------- 1 file changed, 191 insertions(+), 185 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 2706c9002a..d8afe24eeb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -101,6 +101,8 @@ struct BlockFmhaFwdAppendKVPipeline const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, + bool skip_q, + bool skip_kv, void* smem_ptr, index_t rotary_dim = 0) const { @@ -158,206 +160,206 @@ struct BlockFmhaFwdAppendKVPipeline #endif }; - auto knew_window = - make_tile_window(knew_dram_block_window.get_bottom_tensor_view(), - knew_dram_block_window.get_window_lengths(), - knew_dram_block_window.get_window_origin(), - Policy::template MakeKnewDramTileDistribution()); - - auto knew_tile = [&]() { - auto knew = load_tile(knew_window); - return tile_elementwise_in(knew_element_func, knew); - }(); - - // optionally apply rotary embedding to Knew - if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) + if(!skip_kv) { - auto rotary_cos_window = - make_tile_window(knew_rotary_cos_dram_block_window.get_bottom_tensor_view(), - knew_rotary_cos_dram_block_window.get_window_lengths(), - knew_rotary_cos_dram_block_window.get_window_origin(), - Policy::template MakeRotaryCosSinTileDistribution()); + auto knew_window = make_tile_window( + knew_dram_block_window, Policy::template MakeKnewDramTileDistribution()); - auto rotary_sin_window = - 
make_tile_window(knew_rotary_sin_dram_block_window.get_bottom_tensor_view(), - knew_rotary_sin_dram_block_window.get_window_lengths(), - knew_rotary_sin_dram_block_window.get_window_origin(), - Policy::template MakeRotaryCosSinTileDistribution()); - - // We assume that each thread owns contiguous elements on head dimention. And we will - // use the distribution to enable/disable threads in order to override knew_tile content - if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) - { - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - constexpr index_t KPerThread = 16 / sizeof(KDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) - { - constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(knew_tile.thread_buf_[idx]); - const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - knew_tile.thread_buf_[idx] = - type_convert(left * cos - right * sin); - knew_tile.thread_buf_[idx + 1] = - type_convert(right * cos + left * sin); - }); - } - } - else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED - { - constexpr index_t KPerThread = 8 / sizeof(KDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) - { - const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); - - auto knew_other_window = knew_window; - 
move_tile_window(knew_other_window, - {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto knew_other_tile = load_tile(knew_other_window); - - move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_cos_tile = load_tile(rotary_cos_window); - - move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); - static_for<0, thread_buffer_size, 1>{}([&](auto idx) { - const auto curr = type_convert(knew_tile.thread_buf_[idx]); - const auto other = type_convert(knew_other_tile.thread_buf_[idx]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); - - knew_tile.thread_buf_[idx] = - type_convert(curr * cos + other * (is_left ? -sin : sin)); - }); - } - } - } - // print_tile(knew_tile, 7); - store_tile(k_dram_block_window, knew_tile); - - auto vnew_window = - make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(), - vnew_dram_block_window.get_window_lengths(), - vnew_dram_block_window.get_window_origin(), - Policy::template MakeVnewDramTileDistribution()); - - auto vnew_tile = [&]() { - auto vnew = load_tile(vnew_window); - return tile_elementwise_in(vnew_element_func, vnew); - }(); - store_tile(v_dram_block_window, vnew_tile); - - // optionally apply rotary embedding to Q - if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) - { - auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(), - q_dram_block_window.get_window_lengths(), - q_dram_block_window.get_window_origin(), - Policy::template MakeQDramTileDistribution()); - - auto q_tile = [&]() { - auto q = load_tile(q_window); - return tile_elementwise_in(q_element_func, q); + auto knew_tile = [&]() { + auto knew = load_tile(knew_window); + return 
tile_elementwise_in(knew_element_func, knew); }(); - auto rotary_cos_window = - make_tile_window(q_rotary_cos_dram_block_window.get_bottom_tensor_view(), - q_rotary_cos_dram_block_window.get_window_lengths(), - q_rotary_cos_dram_block_window.get_window_origin(), - Policy::template MakeRotaryCosSinTileDistribution()); - - auto rotary_sin_window = - make_tile_window(q_rotary_sin_dram_block_window.get_bottom_tensor_view(), - q_rotary_sin_dram_block_window.get_window_lengths(), - q_rotary_sin_dram_block_window.get_window_origin(), - Policy::template MakeRotaryCosSinTileDistribution()); - - // We assume that each thread owns contiguous elements on head dimention. And we will - // use the distribution to enable/disable threads in order to override q_tile content - if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + // optionally apply rotary embedding to Knew + if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); + auto rotary_cos_window = + make_tile_window(knew_rotary_cos_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution()); - constexpr index_t KPerThread = 16 / sizeof(QDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + auto rotary_sin_window = + make_tile_window(knew_rotary_sin_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution()); - if((start_x + KPerThread) <= rotary_dim) + // We assume that each thread owns contiguous elements on head dimention. 
And we + // will use the distribution to enable/disable threads in order to override + // knew_tile content + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) { - constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(q_tile.thread_buf_[idx]); - const auto right = type_convert(q_tile.thread_buf_[idx + 1]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - q_tile.thread_buf_[idx] = type_convert(left * cos - right * sin); - q_tile.thread_buf_[idx + 1] = - type_convert(right * cos + left * sin); - }); - } - } - else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED - { - constexpr index_t KPerThread = 8 / sizeof(QDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) - { - const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); - - auto q_other_window = q_window; - move_tile_window(q_other_window, - {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto q_other_tile = load_tile(q_other_window); - - move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); auto rotary_cos_tile = load_tile(rotary_cos_window); - - move_tile_window(rotary_sin_window, {0, is_left ? 
0 : -(rotary_dim / 2)}); auto rotary_sin_tile = load_tile(rotary_sin_window); - constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); - static_for<0, thread_buffer_size, 1>{}([&](auto idx) { - const auto curr = type_convert(q_tile.thread_buf_[idx]); - const auto other = type_convert(q_other_tile.thread_buf_[idx]); + constexpr index_t KPerThread = 16 / sizeof(KDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); + if((start_x + KPerThread) <= rotary_dim) + { + constexpr index_t thread_buffer_size = + decltype(knew_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(knew_tile.thread_buf_[idx]); + const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); - q_tile.thread_buf_[idx] = - type_convert(curr * cos + other * (is_left ? 
-sin : sin)); - }); + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + knew_tile.thread_buf_[idx] = + type_convert(left * cos - right * sin); + knew_tile.thread_buf_[idx + 1] = + type_convert(right * cos + left * sin); + }); + } + } + else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + { + constexpr index_t KPerThread = 8 / sizeof(KDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + + if((start_x + KPerThread) <= rotary_dim) + { + const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + + auto knew_other_window = knew_window; + move_tile_window(knew_other_window, + {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto knew_other_tile = load_tile(knew_other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = + decltype(knew_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(knew_tile.thread_buf_[idx]); + const auto other = + type_convert(knew_other_tile.thread_buf_[idx]); + + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); + + knew_tile.thread_buf_[idx] = type_convert( + curr * cos + other * (is_left ? 
-sin : sin)); + }); + } } } - print_tile(q_tile, 8); - store_tile(q_dram_block_window, q_tile); + print_tile(knew_tile, 2); + store_tile(k_dram_block_window, knew_tile); + + auto vnew_window = make_tile_window( + vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution()); + + auto vnew_tile = [&]() { + auto vnew = load_tile(vnew_window); + return tile_elementwise_in(vnew_element_func, vnew); + }(); + store_tile(v_dram_block_window, vnew_tile); + } + + if(!skip_q) + { + // optionally apply rotary embedding to Q + if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) + { + auto q_window = make_tile_window( + q_dram_block_window, Policy::template MakeQDramTileDistribution()); + + auto q_tile = [&]() { + auto q = load_tile(q_window); + return tile_elementwise_in(q_element_func, q); + }(); + + auto rotary_cos_window = + make_tile_window(q_rotary_cos_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution()); + + auto rotary_sin_window = + make_tile_window(q_rotary_sin_dram_block_window, + Policy::template MakeRotaryCosSinTileDistribution()); + + // We assume that each thread owns contiguous elements on head dimention. 
And we + // will use the distribution to enable/disable threads in order to override q_tile + // content + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t KPerThread = 16 / sizeof(QDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + + if((start_x + KPerThread) <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(q_tile.thread_buf_[idx]); + const auto right = type_convert(q_tile.thread_buf_[idx + 1]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + q_tile.thread_buf_[idx] = + type_convert(left * cos - right * sin); + q_tile.thread_buf_[idx + 1] = + type_convert(right * cos + left * sin); + }); + } + } + else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + { + constexpr index_t KPerThread = 8 / sizeof(QDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + + if((start_x + KPerThread) <= rotary_dim) + { + const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + + auto q_other_window = q_window; + move_tile_window(q_other_window, + {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto q_other_tile = load_tile(q_other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 
0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(q_tile.thread_buf_[idx]); + const auto other = type_convert(q_other_tile.thread_buf_[idx]); + + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); + + q_tile.thread_buf_[idx] = type_convert( + curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + // print_tile(q_tile, 8); + store_tile(q_dram_block_window, q_tile); + } } } @@ -380,6 +382,8 @@ struct BlockFmhaFwdAppendKVPipeline const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window, + bool skip_q, + bool skip_kv, void* smem_ptr, index_t rotary_dim = 0) const { @@ -395,6 +399,8 @@ struct BlockFmhaFwdAppendKVPipeline q_rotary_sin_dram_block_window, knew_rotary_cos_dram_block_window, knew_rotary_sin_dram_block_window, + skip_q, + skip_kv, smem_ptr, rotary_dim); }