diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 4ed1c53570..e3fd8f73d2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -788,6 +788,7 @@ struct HstuAttentionFwdKernel is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + // be careful that i_m0 for second_split could be not aligned on kM0 i_m0 = is_tile_in_first_split ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) @@ -1050,9 +1051,10 @@ struct HstuAttentionFwdKernel }(); const auto [seqlen_k_start, seqlen_k_end] = - mask.template GetTileRangeAlongX(i_m0, - number{}, - number{}); + mask.template GetTileRangeAlongX( + i_m0, + number{}, + number{}); if constexpr(!kUseSoftmax) { @@ -1105,9 +1107,10 @@ struct HstuAttentionFwdKernel }(); const auto [seqlen_k_start, seqlen_k_end] = - mask.template GetTileRangeAlongX(i_m0, - number{}, - number{}); + mask.template GetTileRangeAlongX( + i_m0, + number{}, + number{}); if constexpr(!kUseSoftmax) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp index 51e80173ef..7fa242c9b9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp @@ -818,12 +818,16 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy { if constexpr(Problem::kHasDropout) { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; constexpr index_t MWarps = config.template at<1>(); - constexpr index_t kMPerStep = MWarps * WG::kM; - // assume all warps are assigned on dim-M + using BlockGemmShape = 
remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarps * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarps * WG::kM; + // assume the all warps are assigned on dim-M constexpr index_t kNPerStep = WG::kN; return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index b1dc9ac81b..5961621512 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -785,6 +785,7 @@ struct HstuAttentionFwdSplitKVKernel is_tile_in_first_split = (i_tile_m < num_tile_in_first_split); + // be careful i_m0 for second_split could be not aligned on kM0 i_m0 = is_tile_in_first_split ? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0) @@ -1053,9 +1054,10 @@ struct HstuAttentionFwdSplitKVKernel }(); const auto [global_seqlen_k_start, global_seqlen_k_end] = - mask.template GetTileRangeAlongX(i_m0, - number{}, - number{}); + mask.template GetTileRangeAlongX( + i_m0, + number{}, + number{}); const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); @@ -1111,9 +1113,10 @@ struct HstuAttentionFwdSplitKVKernel }(); const auto [global_seqlen_k_start, global_seqlen_k_end] = - mask.template GetTileRangeAlongX(i_m0, - number{}, - number{}); + mask.template GetTileRangeAlongX( + i_m0, + number{}, + number{}); const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention_with_dropout.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention_with_dropout.sh 
new file mode 100644
index 0000000000..aba54d207c
--- /dev/null
+++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention_with_dropout.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+#
+# Runs tile_example_hstu_attention with dropout (-p_drop=0.2) over a set of
+# masking configurations that pass -seqlens and -seqlens_kv separately, for
+# fp16 and bf16, followed by one bf16 run with -save_mask=1.
+#
+# Usage: test_hstu_cross_attention_with_dropout.sh [USE_SOFTMAX]
+#   USE_SOFTMAX             when 1, -softmax=1 is added to every run (default: 0)
+#   TEST_HSTU_FWD_TRAINING  environment variable forwarded as -training=<value>
+#                           (default: 0)
+#
+# Exit status: 0 if every run returned 0, 1 otherwise.
+
+BUILD=build
+BIN="$BUILD/bin/tile_example_hstu_attention"
+
+# An empty or missing first argument falls back to 0.
+USE_SOFTMAX=${1:-0}
+
+Training=${TEST_HSTU_FWD_TRAINING:-0}
+
+# Fail once, up front, instead of once per test case.
+if [ ! -x "$BIN" ]; then
+    echo "error: $BIN not found or not executable" >&2
+    exit 1
+fi
+
+# EXE is expanded unquoted below on purpose so it splits into program + options.
+if [ "$USE_SOFTMAX" -eq 1 ]; then
+    EXE="$BIN -softmax=1 -training=$Training -p_drop=0.2"
+else
+    EXE="$BIN -training=$Training -p_drop=0.2"
+fi
+
+# Count every run that returns non-zero so the script can report failure at the
+# end while still executing the remaining cases.
+failures=0
+trap 'failures=$((failures + 1))' ERR
+
+for T in "fp16" "bf16"; do
+    set -x
+
+    ## no masking batched
+    $EXE -v=1 -prec=$T -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -seqlens_kv=300 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## no masking jagged
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## batched causal
+    $EXE -v=1 -prec=$T -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -seqlens_kv=300 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## batched causal+local
+    $EXE -v=1 -prec=$T -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal+local
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## batched causal+local+context
+    $EXE -v=1 -prec=$T -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal+local+context
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=0 -norm_dist=0
+
+    ## batched causal+local+context+target
+    $EXE -v=1 -prec=$T -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal+local+context+target
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=0 -norm_dist=0
+
+    ## jagged no-causal+local+context+target
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal+local+target (minfull_len > max_uih_len)
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0
+
+    ## jagged causal+local+context+target (minfull_len > max_uih_len)
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0
+
+    ## jagged no-causal+local+context+target (minfull_len > max_uih_len)
+    $EXE -v=1 -prec=$T -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -seqlens_kv=300 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=0 -norm_dist=0
+    set +x
+done
+
+## This case is used to verify the masking when seqlen_kv > seqlen_q by comparing the saved mask tensor with the output of test_pytorch_hstu_mask_v2.py
+set -x
+$EXE -v=1 -prec=bf16 -b=3 -jagged=1 -nhead=1 -hdim_qk=128 -hdim_v=128 -seqlens=52,55,58 -seqlens_kv=70,76,80 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=4,5,6 -attn_scale=0 -norm_dist=0 -save_mask=1
+set +x
+
+trap - ERR
+if [ "$failures" -ne 0 ]; then
+    echo "$failures run(s) returned a non-zero exit status" >&2
+    exit 1
+fi