diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index ecca0e752d..795b9349da 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -363,9 +363,10 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(HstuMask::kUseLocal) { constexpr index_t V_VMEM_LOAD_INST = (kN1 * kK1) / kBlockSize / kAlignmentV; - constexpr index_t K_VMEM_LOAD_INST = (kN0 * kK0) / kBlockSize / kAlignmentV; - constexpr index_t K_LDS_WRITE_INST = - (kN0 * kK0) / kBlockSize / Policy::template GetSmemKPackK(); + constexpr index_t K_VMEM_LOAD_INST = + (kK1 * kSubQKHeaddim) / kBlockSize / kAlignmentK; + constexpr index_t K_LDS_WRITE_INST = (kK1 * kSubQKHeaddim) / kBlockSize / + Policy::template GetSmemKPackK(); constexpr index_t MFMA_INST = (kM0 * kSubQKHeaddim) / kBlockSize / 4; constexpr index_t K_LDS_READ_INST = MFMA_INST / kGemmNumRepM; @@ -401,7 +402,7 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_sched_group_barrier(0x00000008, kGemmNumRepM, 0); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); } else {