diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index d0ff5c5707..a16bec1d37 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -579,15 +579,22 @@ struct HstuAttentionFwdKernel ck_tile::index_t num_tile_in_seqlen = ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0); - if constexpr(kUseGroup) + // when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen + // > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position + if constexpr(!kHasDropout) { - num_tile_in_seqlen += 1; - } - else - { - if(has_minfull_attn_seqlen) + if constexpr(kUseGroup) + { + // always assume minfull_attn_seqlen > 0 since we don't know its exact value + // for each group num_tile_in_seqlen += 1; - }; + } + else + { + if(has_minfull_attn_seqlen) + num_tile_in_seqlen += 1; + }; + } if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim) { @@ -763,35 +770,44 @@ struct HstuAttentionFwdKernel bool is_tile_in_first_split = true; index_t i_m0; - if(kargs.min_full_attn_seqlen > 0) + // when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen + // > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position + if constexpr(!kHasDropout) { - // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len - if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) + if(kargs.min_full_attn_seqlen > 0) { - seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; + // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len + if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) + { + seqlen_in_first_split = + kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; - index_t num_tile_in_first_split = - __builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil( - seqlen_in_first_split, HstuAttentionPipeline::kM0)); + index_t num_tile_in_first_split = + __builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil( + seqlen_in_first_split, HstuAttentionPipeline::kM0)); - is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); - i_m0 = is_tile_in_first_split - ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) - : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * - HstuAttentionPipeline::kM0) + - seqlen_in_first_split; + i_m0 = + is_tile_in_first_split + ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) + : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * + HstuAttentionPipeline::kM0) + + seqlen_in_first_split; + } + else + { + seqlen_in_first_split = 0; + is_tile_in_first_split = false; + + // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor + kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; + + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + }; } else - { - seqlen_in_first_split = 0; - is_tile_in_first_split = false; - - // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor - kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; - i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); - }; } else i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); @@ -953,6 +969,36 @@ struct HstuAttentionFwdKernel } }(); + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to make a tile window from this null_randval_dram since the null_tile_window + // does not have store_tile() over-loaded, which will cause compiling issue when + // used inside BlockDropout::Run(). Also we need this dram window to provide + // start_m0_idx used in BlockDropout::Run() + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv), + make_tuple(kargs.seqlen_kv, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tile_window(null_randval_dram, + make_tuple(number{}, + number{}), + {i_m0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { @@ -1014,6 +1060,7 @@ struct HstuAttentionFwdKernel k_dram_window, v_dram_window, bias_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1029,6 +1076,7 @@ struct HstuAttentionFwdKernel v_dram_window, bias_dram_window, lse_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1067,6 +1115,7 @@ struct HstuAttentionFwdKernel k_dram_window, v_dram_window, bias_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1082,6 +1131,7 @@ struct HstuAttentionFwdKernel v_dram_window, bias_dram_window, lse_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp index ffbc453cf8..51e80173ef 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp @@ -592,6 +592,75 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy constexpr index_t WarpGemmK = Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}); +#ifdef __gfx950__ + static_assert((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16), + "Not supported WarpGemm sizes!"); +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); +#endif + + return WarpGemmDispatcher< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); + } + }(); + + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, + decltype(warp_gemm)>; + + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2Hack_0{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } + + // Same as GetQKBlockGemm but with kN0 (instead of kN0Sub) as the N tile dimension. + // This is used as the BlockGemm template argument to BlockDropout::Run() so that + // kNPerBlock = kN0, ensuring dropout is applied to the full pcomp_tile [kM0, kN0] + // rather than only the first kN0Sub columns. + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKCombinedBlockGemm() + { + using GemmProblem = BlockGemmProblem< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::kNumGemm0Warps * get_warp_size(), + TileGemmShape, + typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, + typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>; + + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v) && + std::is_same_v) + { + constexpr index_t WarpGemmM = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}); + #ifdef __gfx950__ static_assert((WarpGemmM == 16 && WarpGemmK == 32) || (WarpGemmM == 32 && WarpGemmK == 16), diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index ec33bca893..3c7f180d5d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -540,15 +540,22 @@ struct HstuAttentionFwdSplitKVKernel ck_tile::index_t num_tile_in_seqlen = ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0); - if constexpr(kUseGroup) + // when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen + // > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position + if constexpr(!kHasDropout) { - num_tile_in_seqlen += 1; - } - else - { - if(has_minfull_attn_seqlen) + if constexpr(kUseGroup) + { + // always assume minfull_attn_seqlen > 0 since we don't know its exact value + // for each group num_tile_in_seqlen += 1; - }; + } + else + { + if(has_minfull_attn_seqlen) + num_tile_in_seqlen += 1; + }; + } if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim) { @@ -760,35 +767,44 @@ struct HstuAttentionFwdSplitKVKernel bool is_tile_in_first_split = true; index_t i_m0; - if(kargs.min_full_attn_seqlen > 0) + // when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen + // > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position + if constexpr(!kHasDropout) { - // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len - if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) + if(kargs.min_full_attn_seqlen > 0) { - seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; + // need consider for cases where min_full_attn_seqlen be bigger than max_uih_len + if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen) + { + seqlen_in_first_split = + kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen; - index_t num_tile_in_first_split = - __builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil( - seqlen_in_first_split, HstuAttentionPipeline::kM0)); + index_t num_tile_in_first_split = + __builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil( + seqlen_in_first_split, HstuAttentionPipeline::kM0)); - is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); - i_m0 = is_tile_in_first_split - ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) - : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * - HstuAttentionPipeline::kM0) + - seqlen_in_first_split; + i_m0 = + is_tile_in_first_split + ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) + : __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) * + HstuAttentionPipeline::kM0) + + seqlen_in_first_split; + } + else + { + seqlen_in_first_split = 0; + is_tile_in_first_split = false; + + // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor + kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; + + i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + }; } else - { - seqlen_in_first_split = 0; - is_tile_in_first_split = false; - - // adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor - kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target; - i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); - }; } else i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); @@ -956,6 +972,36 @@ struct HstuAttentionFwdSplitKVKernel } }(); + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to make a tile window from this null_randval_dram since the null_tile_window + // does not have store_tile() over-loaded, which will cause compiling issue when + // used inside BlockDropout::Run(). Also we need this dram window to provide + // start_m0_idx used in BlockDropout::Run() + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv), + make_tuple(kargs.seqlen_kv, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tile_window(null_randval_dram, + make_tuple(number{}, + number{}), + {i_m0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { @@ -1020,6 +1066,7 @@ struct HstuAttentionFwdSplitKVKernel k_dram_window, v_dram_window, bias_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1035,6 +1082,7 @@ struct HstuAttentionFwdSplitKVKernel v_dram_window, bias_dram_window, lse_acc_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1076,6 +1124,7 @@ struct HstuAttentionFwdSplitKVKernel k_dram_window, v_dram_window, bias_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, @@ -1091,6 +1140,7 @@ struct HstuAttentionFwdSplitKVKernel v_dram_window, bias_dram_window, lse_acc_dram_window, + null_randval_window, seqlen_k_start, seqlen_k_end, mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 61510219f5..6d964c89ac 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -114,12 +114,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename NullRandValDramWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -155,6 +157,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); + using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm()); + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); @@ -265,31 +269,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto null_randval_window = [&]() { - if constexpr(kHasDropout) - { - // need to make a tile window from this null_randval_dram since the null_tile_window - // does not have store_tile() over-loaded, will cause compiling issue when used - // inside BlockDropout::Run() - const auto null_randval_dram = [&]() { - const auto null_dram_naive = make_naive_tensor_view( - static_cast(nullptr), - make_tuple(1, 1), - make_tuple(1, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(null_dram_naive, - make_tuple(number<1>{}, number<1>{}), - sequence{}); - }(); - - return make_tile_window( - null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); - } - else - return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); - }(); + auto null_randval_window = dropout.template MakeRandvalDramWindow( + null_randval_window_tmp, seqlen_k_start); clear_tile(o_acc); @@ -396,17 +377,21 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, pcomp_tile); - seqlen_k_curr += kN0; - if constexpr(kHasDropout) { + __builtin_amdgcn_sched_barrier(0); + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( + dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + + __builtin_amdgcn_sched_barrier(0); } + seqlen_k_curr += kN0; + auto p = cast_tile(pcomp_tile); // STAGE 3, Gemm_1 ( O = P@V ) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 587581188c..d0b4cba89d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -115,12 +115,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename NullRandValDramWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -156,6 +158,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); + using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm()); + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); @@ -272,31 +276,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto null_randval_window = [&]() { - if constexpr(kHasDropout) - { - // need to make a tile window from this null_randval_dram since the null_tile_window - // does not have store_tile() over-loaded, will cause compiling issue when used - // inside BlockDropout::Run() - const auto null_randval_dram = [&]() { - const auto null_dram_naive = make_naive_tensor_view( - static_cast(nullptr), - make_tuple(1, 1), - make_tuple(1, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(null_dram_naive, - make_tuple(number<1>{}, number<1>{}), - sequence{}); - }(); - - return make_tile_window( - null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); - } - else - return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); - }(); + auto null_randval_window = dropout.template MakeRandvalDramWindow( + null_randval_window_tmp, seqlen_k_start); clear_tile(o_acc); @@ -383,18 +364,22 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad detail::scale_tile_in_pack(pcomp_tile, scale_p); - seqlen_k_curr += kN0; - if constexpr(kHasDropout) { + __builtin_amdgcn_sched_barrier(0); + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( + dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + + __builtin_amdgcn_sched_barrier(0); } + seqlen_k_curr += kN0; + auto p = cast_tile(pcomp_tile); // check whether first V-LdsBufer overlap with last K-LdsBuffer, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index bb47a2f81c..e2896755a7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEorLSEaccDramBlockWindow, + typename NullRandValDramWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile @@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile + NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -164,6 +166,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); + using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm()); + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); @@ -296,31 +300,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto null_randval_window = [&]() { - if constexpr(kHasDropout) - { - // need to make a tile window from this null_randval_dram since the null_tile_window - // does not have store_tile() over-loaded, will cause compiling issue when used - // inside BlockDropout::Run() - const auto null_randval_dram = [&]() { - const auto null_dram_naive = make_naive_tensor_view( - static_cast(nullptr), - make_tuple(1, 1), - make_tuple(1, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(null_dram_naive, - make_tuple(number<1>{}, number<1>{}), - sequence{}); - }(); - - return make_tile_window( - null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); - } - else - return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); - }(); + auto null_randval_window = dropout.template MakeRandvalDramWindow( + null_randval_window_tmp, seqlen_k_start); clear_tile(o_acc); set_tile(m, -numeric::infinity()); @@ -522,17 +503,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS l(i_idx) = l[i_idx] + rowsum_p[i_idx]; }); - seqlen_k_curr += kN0; - if constexpr(kHasDropout) { + __builtin_amdgcn_sched_barrier(0); + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( + dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + + __builtin_amdgcn_sched_barrier(0); } + seqlen_k_curr += kN0; + auto p = cast_tile(pcomp_tile); __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 2291d0fac9..d9cc319668 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEorLSEaccDramBlockWindow, + typename NullRandValDramWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile @@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile + NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -165,6 +167,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); + using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm()); + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); @@ -303,31 +307,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto null_randval_window = [&]() { - if constexpr(kHasDropout) - { - // need to make a tile window from this null_randval_dram since the null_tile_window - // does not have store_tile() over-loaded, will cause compiling issue when used - // inside BlockDropout::Run() - const auto null_randval_dram = [&]() { - const auto null_dram_naive = make_naive_tensor_view( - static_cast(nullptr), - make_tuple(1, 1), - make_tuple(1, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(null_dram_naive, - make_tuple(number<1>{}, number<1>{}), - sequence{}); - }(); - - return make_tile_window( - null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); - } - else - return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); - }(); + auto null_randval_window = dropout.template MakeRandvalDramWindow( + null_randval_window_tmp, seqlen_k_start); clear_tile(o_acc); set_tile(m, -numeric::infinity()); @@ -528,18 +509,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad l(i_idx) = l[i_idx] + rowsum_p[i_idx]; }); - seqlen_k_curr += kN0; - if constexpr(kHasDropout) { + __builtin_amdgcn_sched_barrier(0); + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( + dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + + __builtin_amdgcn_sched_barrier(0); } + seqlen_k_curr += kN0; + auto p = cast_tile(pcomp_tile); __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh new file mode 100644 index 0000000000..eef7fe6b8f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_attention_with_dropout.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +BUILD=build +EXE="$BUILD/bin/tile_example_hstu_attention -p_drop=0.2" + +attn_scale=0 +if [ $# -ge 1 ]; then + attn_scale=$1 +fi + +ndist=0 + +if [ $# -ge 2 ]; then + ndist=$2 +fi + +for dtype in "fp16" "bf16"; do + set -x + + ## no masking batched + $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## no masking jagged + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal + $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local + $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context + $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist + + ## batched causal+local+context+target + $EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + + ## jagged no-causal+local+context+target (minfull_len > max_uih_len) + $EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist + set +x +done