diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 8abdd54cd9..1512c6ae34 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -333,23 +333,15 @@ struct BlockDropout return randval; }; - if(is_store_randval) - { - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - const auto randval = generate_randval(i_m0, i_n0); - // save to Global - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - }); - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - }); - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { const auto randval = generate_randval(i_m0, i_n0); + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + } + move_tile_window(randval_dram_window, {0, kNPerStep}); // Drop values of P based on the generated probabilities constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { @@ -369,7 +361,9 @@ struct BlockDropout }); }); }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); } const unsigned long long ph_seed; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index d991d5fe25..3b476299e1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1005,7 +1005,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel rand_val_ptr, make_tuple(kargs.seqlen_q, kargs.seqlen_k), make_tuple(kargs.stride_randval, 1), - number<1>{}, + number{}, number<1>{}); return pad_tensor_view(randval_dram_naive, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fe7c8d48c8..fba3065842 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1450,7 +1450,7 @@ struct FmhaFwdKernel rand_val_ptr, make_tuple(kargs.seqlen_q, kargs.seqlen_k), make_tuple(kargs.stride_randval, 1), - number<1>{}, + number{}, number<1>{}); return pad_tensor_view(randval_dram_naive, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index b01c127a21..7a8e9a1d47 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -80,6 +80,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 0836fbfce3..9ec82617b1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -83,6 +83,8 @@ struct BlockFmhaPipelineQRKSVS kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index ba788c7f1e..b67c28401f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -81,6 +81,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 1d998ba4f6..08fc42a471 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -90,6 +90,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index a1b1e0e158..74e91aac56 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -69,6 +69,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 074a94613c..1283782d06 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -74,6 +74,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 4efcd871dc..8309c1ec2a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -78,6 +78,8 @@ struct BlockFmhaPipelineQSKSVS kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kAlignmentRandVal = + kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 06c6dce6b0..692a1cfa13 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -422,6 +422,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentRandVal() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + + constexpr auto c_warp_y_lengths = CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::RandValOutputDataType); + return min(MaxVectorSize, c_warp_y_lengths.get(number{})); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() {