diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 1f0fe2bd64..7ebb306cce 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -8,6 +8,20 @@ namespace ck_tile { +struct NullBlockDropout +{ + template + __host__ __device__ static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + struct BlockDropout { CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, @@ -195,6 +209,42 @@ struct BlockDropout MakeRandValLdsShuffleTileDistribution()); const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + if(is_store_randval) + { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, + reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + // save to Global + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + }; static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); @@ -232,23 +282,8 @@ struct BlockDropout : PComputeDataType(0); }); }); - // save to Global - if(is_store_randval) - { - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - } }); - if(is_store_randval) - { - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - } }); - if(is_store_randval) - { - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } } template ::max(); - uint64_t drop_seed = 0; - uint64_t drop_offset = 0; - bool is_store_randval = false; - - if constexpr(kHasDropout) - { - rp_undrop = kargs.rp_undrop; - p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; - drop_seed = kargs.drop_seed; - drop_offset = kargs.drop_offset; - is_store_randval = kargs.is_store_randval; - } - BlockDropout dropout(i_batch, - i_nhead, - kargs.num_head_q, - drop_seed, - drop_offset, - rp_undrop, - p_undrop_in_uint8_t, - is_store_randval); + auto dropout = [&]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch, + i_nhead, + kargs.num_head_q, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); auto randval_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto randval_dram_window_lengths = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 06ce3a6514..a392f0124d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr const char* name = "qr"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -246,7 +248,7 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( + auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = @@ -486,7 +488,7 @@ struct BlockFmhaPipelineQRKSVS if constexpr(kHasDropout) { - dropout.Run( + dropout.template Run( smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); } @@ -618,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 21784fc2d2..e9a14ca5ac 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -112,6 +112,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr const char* name = "qr_async"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -151,7 +153,7 @@ struct BlockFmhaPipelineQRKSVSAsync PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -298,7 +300,7 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( + auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = @@ -571,7 +573,7 @@ struct BlockFmhaPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.Run( + dropout.template Run( randval_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, @@ -728,7 +730,7 @@ struct BlockFmhaPipelineQRKSVSAsync PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{},