diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b6054f6285..98d2a2ec62 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -542,6 +542,38 @@ bool run(const ck_tile::ArgParser& arg_parser) auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin(shape_seqlen_k, rotary_dim, seed); + HOST_DEBUG_STMTS + { +#if 0 + printf("rotary_cos's shape: (%2zu, %2zu)\n", + rotary_cos_host.get_length(0), + rotary_cos_host.get_length(1)); + for(size_t row = 0; row < rotary_cos_host.get_length(0); ++row) + { + printf("[HOST] rotary_cos[%3zu] = ", row); + for(size_t col = 0; col < rotary_cos_host.get_length(1); ++col) + { + printf("%11.7f", ck_tile::type_convert(rotary_cos_host(row, col))); + } + printf("\n"); + } +#endif +#if 1 + printf("rotary_sin's shape: (%2zu, %2zu)\n", + rotary_sin_host.get_length(0), + rotary_sin_host.get_length(1)); + for(size_t row = 0; row < rotary_sin_host.get_length(0); ++row) + { + printf("[HOST] rotary_sin[%3zu] = ", row); + for(size_t col = 0; col < rotary_sin_host.get_length(1); ++col) + { + printf("%11.7f", ck_tile::type_convert(rotary_sin_host(row, col))); + } + printf("\n"); + } +#endif + } + ck_tile::HostTensor lse_acc_host( 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} : std::array{1, 1, 1, 1}); @@ -580,11 +612,10 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - - std::fill(knew_host.begin(), knew_host.end(), static_cast(99.f)); - std::fill(vnew_host.begin(), vnew_host.end(), static_cast(99.f)); } else if(init_method == "nf") { @@ -1057,7 +1088,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); // optionally apply RoPE to the q_host_ref - if(0 < rotary_dim) + if(false && 0 < rotary_dim) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); @@ -1073,6 +1104,8 @@ bool run(const ck_tile::ArgParser& arg_parser) // copy Knew to the end of K if(0 < seqlen_knew) { + printf("\n"); + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); }); else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); }); @@ -1094,6 +1127,19 @@ bool run(const ck_tile::ArgParser& arg_parser) real_knew_host_ref = &knew_host_ref_ro.value(); } + HOST_DEBUG_STMTS { + for(size_t row = 0; row < real_knew_host_ref->get_length(1); ++row) + { + printf("[HOST] real_knew_host_ref[%3zu] = ", row); + for(size_t col = 0; col < real_knew_host_ref->get_length(2); ++col) + { + printf("%11.7f", + ck_tile::type_convert((*real_knew_host_ref)(0, row, col))); + } + printf("\n"); + } + } + const std::size_t knew_start = real_seqlen_k - seqlen_knew; k_host_ref.ForEach([&](auto& self, auto i) { if(knew_start <= i[1]) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index aac9eefc54..fec294ea5e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -104,6 +104,13 @@ struct BlockFmhaFwdAppendKVPipeline index_t rotary_dim = 0, bool is_rotary_interleaved = false) const { + + auto* const ksmem = reinterpret_cast(smem_ptr); + if(threadIdx.x == 0) + { + printf("\n"); + } + (void)q_dram_block_window_tmp; (void)q_element_func; (void)k_dram_block_window_tmp; @@ -132,7 +139,109 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template MakeKnewDramTileDistribution()); auto knew_tile = load_tile(knew_dram_window); - /// TODO: apply RoPE on knew_tile here + + if constexpr(kApplyRoPE) + { + auto rotary_cos_window = make_tile_window( + rotary_cos_block_window_tmp.get_bottom_tensor_view(), + rotary_cos_block_window_tmp.get_window_lengths(), + rotary_cos_block_window_tmp.get_window_origin(), + Policy::template MakeRotaryCosSinInterleaveDramTileDistribution()); + + auto rotary_sin_window = make_tile_window( + rotary_sin_block_window_tmp.get_bottom_tensor_view(), + rotary_sin_block_window_tmp.get_window_lengths(), + rotary_sin_block_window_tmp.get_window_origin(), + Policy::template MakeRotaryCosSinInterleaveDramTileDistribution()); + + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + +#if 0 + auto rotary_tile = rotary_sin_tile; + constexpr auto spans = decltype(rotary_tile)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = tile_idx.at(number<0>{}); + const auto col = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx); + }); + }); + + block_sync_lds(); + + DEVICE_DEBUG_STMTS + { + for(int row = 0; row < kTileSizeSk; ++row) + { + printf("[DEVICE] rotary_smem[%3d] = ", row); + for(int col = 0; col < rotary_dim / 2; ++col) + { + printf("%11.7f", type_convert(ksmem[row * (kTileSizeD / 2) + col])); + } + printf("\n"); + } + } +#endif + constexpr index_t KPerThread = 16 / sizeof(KDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock); + if(start_x + KPerThread <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(knew_tile.thread_buf_[idx]); + const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); + + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + knew_tile.thread_buf_[idx] = left * cos - right * sin; + knew_tile.thread_buf_[idx + 1] = right * cos + left * sin; + }); + } +#if 0 + DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); } + + { + constexpr auto spans = decltype(knew_tile)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + knew_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = tile_idx.at(number<0>{}); + const auto col = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + ksmem[row * kTileSizeD + col] = knew_tile(i_j_idx); + }); + }); + } + + block_sync_lds(); + + DEVICE_DEBUG_STMTS + { + for(int row = 0; row < 7; ++row) + { + printf("[DEVICE] knew_tile[%3d] = ", row); + for(int col = 0; col < kTileSizeD; ++col) + { + printf("%11.7f", type_convert(ksmem[row * kTileSizeD + col])); + } + printf("\n"); + } + } +#endif + } store_tile(k_dram_block_window_tmp, knew_tile); auto vnew_dram_block_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index cd38c22e8f..fc664dbd62 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -54,7 +54,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy { using KDataType = remove_cvref_t; - return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2); + return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD); } template