diff --git a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp index ba056108ac..5173279299 100644 --- a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp +++ b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp @@ -34,4 +34,75 @@ struct RotaryEmbeddingEnumToStr static constexpr const char* name = "half"; }; +template +struct BlockRotaryEmbedding +{ + template + CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile, + OtherDramBlockWindow other_window, + RotaryCosDramBlockWindow rotary_cos_window, + RotarySinDramBlockWindow rotary_sin_window, + index_t rotary_dim, + index_t thread_end) + { + using DataType = typename remove_cvref_t::DataType; + + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + if(thread_end <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(tile.thread_buf_[idx]); + const auto right = type_convert(tile.thread_buf_[idx + 1]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + tile.thread_buf_[idx] = type_convert(left * cos - right * sin); + tile.thread_buf_[idx + 1] = type_convert(right * cos + left * sin); + }); + } + } + else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + if(thread_end <= rotary_dim) + { + const bool is_left = (thread_end <= (rotary_dim / 2)); + + move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto other_tile = load_tile(other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(tile.thread_buf_[idx]); + const auto other = type_convert(other_tile.thread_buf_[idx]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx]); + + tile.thread_buf_[idx] = + type_convert(curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 768ac08628..383227bb07 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -182,69 +182,18 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template MakeRotaryCosSinTileDistribution()); // We assume that each thread owns contiguous elements on head dimention. And we - // will use the distribution to enable/disable threads in order to override + // will use the distribution to enable/disable threads in order to override partial // knew_tile content auto [thread_start, thread_end] = Policy::template GetKnewThreadRangeAlongK(); ignore = thread_start; - if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) - { - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - if(thread_end <= rotary_dim) - { - constexpr index_t thread_buffer_size = - decltype(knew_tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(knew_tile.thread_buf_[idx]); - const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); - - const auto cos = - type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = - type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - knew_tile.thread_buf_[idx] = - type_convert(left * cos - right * sin); - knew_tile.thread_buf_[idx + 1] = - type_convert(right * cos + left * sin); - }); - } - } - else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED - { - if(thread_end <= rotary_dim) - { - const bool is_left = (thread_end <= (rotary_dim / 2)); - - auto knew_other_window = knew_window; - move_tile_window(knew_other_window, - {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto knew_other_tile = load_tile(knew_other_window); - - move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_cos_tile = load_tile(rotary_cos_window); - - move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - constexpr index_t thread_buffer_size = - decltype(knew_tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 1>{}([&](auto idx) { - const auto curr = type_convert(knew_tile.thread_buf_[idx]); - const auto other = - type_convert(knew_other_tile.thread_buf_[idx]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); - - knew_tile.thread_buf_[idx] = type_convert( - curr * cos + other * (is_left ? -sin : sin)); - }); - } - } + BlockRotaryEmbedding::apply(knew_tile, + knew_window, + rotary_cos_window, + rotary_sin_window, + rotary_dim, + thread_end); } print_tile(knew_tile, 2); store_tile(k_dram_block_window, knew_tile); @@ -281,65 +230,14 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template MakeRotaryCosSinTileDistribution()); // We assume that each thread owns contiguous elements on head dimention. And we - // will use the distribution to enable/disable threads in order to override q_tile - // content + // will use the distribution to enable/disable threads in order to override partial + // q_tile content auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK(); ignore = thread_start; - if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) - { - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); + BlockRotaryEmbedding::apply( + q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end); - if(thread_end <= rotary_dim) - { - constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(q_tile.thread_buf_[idx]); - const auto right = type_convert(q_tile.thread_buf_[idx + 1]); - - const auto cos = - type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = - type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - q_tile.thread_buf_[idx] = - type_convert(left * cos - right * sin); - q_tile.thread_buf_[idx + 1] = - type_convert(right * cos + left * sin); - }); - } - } - else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED - { - if(thread_end <= rotary_dim) - { - const bool is_left = (thread_end <= (rotary_dim / 2)); - - auto q_other_window = q_window; - move_tile_window(q_other_window, - {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto q_other_tile = load_tile(q_other_window); - - move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_cos_tile = load_tile(rotary_cos_window); - - move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 1>{}([&](auto idx) { - const auto curr = type_convert(q_tile.thread_buf_[idx]); - const auto other = type_convert(q_other_tile.thread_buf_[idx]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); - - q_tile.thread_buf_[idx] = type_convert( - curr * cos + other * (is_left ? -sin : sin)); - }); - } - } // print_tile(q_tile, 8); store_tile(q_dram_block_window, q_tile); }