diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index ef6eb20a3e..b4311483cd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1098,7 +1098,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); // optionally apply RoPE to the q_host_ref - if(false && 0 < rotary_dim) + if(0 < rotary_dim) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); @@ -1107,6 +1107,26 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } + #if 1 + HOST_DEBUG_STMTS { + printf("\n"); + for(size_t row = 0; row < q_host_ref.get_length(1) && row < 8; ++row) + { + printf("[HOST] q_host_ref[%3zu] = ", row); + for(size_t col = 0; col < q_host_ref.get_length(2); ++col) + { + if (0 < col && col % 8 == 0) { + printf("|"); + } + + printf("%11.7f", + ck_tile::type_convert(q_host_ref(0, row, col))); + + } + printf("\n"); + } + } + #endif if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); @@ -1114,8 +1134,6 @@ bool run(const ck_tile::ArgParser& arg_parser) // copy Knew to the end of K if(0 < seqlen_knew) { - printf("\n"); - ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); }); else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); }); @@ -1125,6 +1143,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::optional knew_host_ref_ro; #if 0 HOST_DEBUG_STMTS { + printf("\n"); for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row) { printf("[HOST] real_knew_host[%3zu] = ", row); @@ -1157,6 +1176,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } #if 0 HOST_DEBUG_STMTS { + printf("\n"); for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row) { printf("[HOST] real_knew_host_ref[%3zu] = ", row); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 4d65bfed62..ba7e9472a3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -518,56 +518,110 @@ struct FmhaFwdAppendKVKernel sequence{}); } }(); - constexpr auto rotary_cos_sin_dram_window_lengths = - make_tuple(number{}, number{}); - const auto rotary_cos_dram_window = [&]() { + + constexpr auto q_rotary_cos_sin_dram_window_lengths = + make_tuple(number{}, number{}); + const auto q_rotary_cos_dram_window = [&]() { if constexpr(kApplyRoPE) { - const auto rotary_cos_dram = [&]() { - const auto rotary_cos_dram_native = - make_naive_tensor_view( - reinterpret_cast(kargs.rotary_cos_ptr), - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), - make_tuple(kargs.rotary_dim / 2, 1), - number<8>{}, - number<1>{}); + const auto rotary_cos_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_cos_ptr), + make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + const auto rotary_cos_dram = [&]() { return pad_tensor_view(rotary_cos_dram_native, - rotary_cos_sin_dram_window_lengths, - sequence{}); + q_rotary_cos_sin_dram_window_lengths, + sequence{}); }(); return make_tile_window( - rotary_cos_dram, rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0}); } else { - return make_null_tile_window(rotary_cos_sin_dram_window_lengths); + return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths); } }(); - const auto rotary_sin_dram_window = [&]() { + const auto q_rotary_sin_dram_window = [&]() { if constexpr(kApplyRoPE) { - const auto rotary_sin_dram = [&]() { - const auto rotary_sin_dram_native = - make_naive_tensor_view( - reinterpret_cast(kargs.rotary_sin_ptr), - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), - make_tuple(kargs.rotary_dim / 2, 1), - number<8>{}, - number<1>{}); + const auto rotary_sin_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_sin_ptr), + make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + const auto rotary_sin_dram = [&]() { return pad_tensor_view(rotary_sin_dram_native, - rotary_cos_sin_dram_window_lengths, - sequence{}); + q_rotary_cos_sin_dram_window_lengths, + sequence{}); }(); return make_tile_window( - rotary_sin_dram, rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0}); } else { - return make_null_tile_window(rotary_cos_sin_dram_window_lengths); + return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths); + } + }(); + + constexpr auto knew_rotary_cos_sin_dram_window_lengths = + make_tuple(number{}, number{}); + const auto knew_rotary_cos_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_cos_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_cos_ptr), + make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + + const auto rotary_cos_dram = [&]() { + return pad_tensor_view(rotary_cos_dram_native, + knew_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0}); + } + else + { + return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths); + } + }(); + const auto knew_rotary_sin_dram_window = [&]() { + if constexpr(kApplyRoPE) + { + const auto rotary_sin_dram_native = + make_naive_tensor_view( + reinterpret_cast(kargs.rotary_sin_ptr), + make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.rotary_dim / 2, 1), + number<8>{}, + number<1>{}); + + const auto rotary_sin_dram = [&]() { + return pad_tensor_view(rotary_sin_dram_native, + knew_rotary_cos_sin_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window( + rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0}); + } + else + { + return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths); } }(); @@ -618,8 +672,10 @@ struct FmhaFwdAppendKVKernel knew_dram_window, v_dram_window, vnew_dram_window, - rotary_cos_dram_window, - rotary_sin_dram_window, + q_rotary_cos_dram_window, + q_rotary_sin_dram_window, + knew_rotary_cos_dram_window, + knew_rotary_sin_dram_window, smem_ptr, kargs.rotary_dim); } @@ -630,8 +686,10 @@ struct FmhaFwdAppendKVKernel knew_dram_window, v_dram_window, vnew_dram_window, - rotary_cos_dram_window, - rotary_sin_dram_window, + q_rotary_cos_dram_window, + q_rotary_sin_dram_window, + knew_rotary_cos_dram_window, + knew_rotary_sin_dram_window, smem_ptr); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 416f321487..2706c9002a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -84,8 +84,10 @@ struct BlockFmhaFwdAppendKVPipeline typename QElementFunction, typename KnewElementFunction, typename VnewElementFunction, - typename RotaryCosDramBlockWindow, - typename RotarySinDramBlockWindow> + typename QRotaryCosDramBlockWindow, + typename QRotarySinDramBlockWindow, + typename KnewRotaryCosDramBlockWindow, + typename KnewRotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, @@ -95,8 +97,10 @@ struct BlockFmhaFwdAppendKVPipeline VDramBlockWindow& v_dram_block_window, // N1*K1 tile const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile const VnewElementFunction& vnew_element_func, - const RotaryCosDramBlockWindow rotary_cos_dram_block_window, - const RotarySinDramBlockWindow rotary_sin_dram_block_window, + const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, + const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, + const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, + const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, void* smem_ptr, index_t rotary_dim = 0) const { @@ -169,15 +173,15 @@ struct BlockFmhaFwdAppendKVPipeline if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { auto rotary_cos_window = - make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(), - rotary_cos_dram_block_window.get_window_lengths(), - rotary_cos_dram_block_window.get_window_origin(), + make_tile_window(knew_rotary_cos_dram_block_window.get_bottom_tensor_view(), + knew_rotary_cos_dram_block_window.get_window_lengths(), + knew_rotary_cos_dram_block_window.get_window_origin(), Policy::template MakeRotaryCosSinTileDistribution()); auto rotary_sin_window = - make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(), - rotary_sin_dram_block_window.get_window_lengths(), - rotary_sin_dram_block_window.get_window_origin(), + make_tile_window(knew_rotary_sin_dram_block_window.get_bottom_tensor_view(), + knew_rotary_sin_dram_block_window.get_window_lengths(), + knew_rotary_sin_dram_block_window.get_window_origin(), Policy::template MakeRotaryCosSinTileDistribution()); // We assume that each thread owns contiguous elements on head dimention. And we will @@ -274,15 +278,85 @@ struct BlockFmhaFwdAppendKVPipeline auto q = load_tile(q_window); return tile_elementwise_in(q_element_func, q); }(); - print_tile(q_tile, 8); - /// TODO: add rotary_cos/rotary_sin windows for Q (tile size: M0xK0) + + auto rotary_cos_window = + make_tile_window(q_rotary_cos_dram_block_window.get_bottom_tensor_view(), + q_rotary_cos_dram_block_window.get_window_lengths(), + q_rotary_cos_dram_block_window.get_window_origin(), + Policy::template MakeRotaryCosSinTileDistribution()); + + auto rotary_sin_window = + make_tile_window(q_rotary_sin_dram_block_window.get_bottom_tensor_view(), + q_rotary_sin_dram_block_window.get_window_lengths(), + q_rotary_sin_dram_block_window.get_window_origin(), + Policy::template MakeRotaryCosSinTileDistribution()); + // We assume that each thread owns contiguous elements on head dimention. And we will - // use the distribution to enable/disable threads in order to override knew_tile content - if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {} + // use the distribution to enable/disable threads in order to override q_tile content + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t KPerThread = 16 / sizeof(QDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + + if((start_x + KPerThread) <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(q_tile.thread_buf_[idx]); + const auto right = type_convert(q_tile.thread_buf_[idx + 1]); + + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + q_tile.thread_buf_[idx] = type_convert(left * cos - right * sin); + q_tile.thread_buf_[idx + 1] = + type_convert(right * cos + left * sin); + }); + } + } else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED { - } + constexpr index_t KPerThread = 8 / sizeof(QDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + if((start_x + KPerThread) <= rotary_dim) + { + const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + + auto q_other_window = q_window; + move_tile_window(q_other_window, + {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto q_other_tile = load_tile(q_other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(q_tile.thread_buf_[idx]); + const auto other = type_convert(q_other_tile.thread_buf_[idx]); + + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx]); + + q_tile.thread_buf_[idx] = + type_convert(curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + print_tile(q_tile, 8); store_tile(q_dram_block_window, q_tile); } } @@ -292,16 +366,20 @@ struct BlockFmhaFwdAppendKVPipeline typename KnewDramBlockWindow, typename VDramBlockWindow, typename VnewDramBlockWindow, - typename RotaryCosDramBlockWindow, - typename RotarySinDramBlockWindow> + typename QRotaryCosDramBlockWindow, + typename QRotarySinDramBlockWindow, + typename KnewRotaryCosDramBlockWindow, + typename KnewRotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, KDramBlockWindow& k_dram_block_window, const KnewDramBlockWindow& knew_dram_block_window, VDramBlockWindow& v_dram_block_window, const VnewDramBlockWindow& vnew_dram_block_window, - const RotaryCosDramBlockWindow& rotary_cos_dram_block_window, - const RotarySinDramBlockWindow& rotary_sin_dram_block_window, + const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window, + const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window, + const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window, + const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window, void* smem_ptr, index_t rotary_dim = 0) const { @@ -313,8 +391,10 @@ struct BlockFmhaFwdAppendKVPipeline v_dram_block_window, vnew_dram_block_window, identity{}, - rotary_cos_dram_block_window, - rotary_sin_dram_block_window, + q_rotary_cos_dram_block_window, + q_rotary_sin_dram_block_window, + knew_rotary_cos_dram_block_window, + knew_rotary_sin_dram_block_window, smem_ptr, rotary_dim); }