From 1cf868026b8deb3c397d88b80704a050a4a4066f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 14 Dec 2025 04:20:05 +0000 Subject: [PATCH] Add support of loading QK tiles of hdim96 without padding to hdim128 --- ...stu_attention_batched_forward_dispatch.hpp | 2 +- .../hstu_attention_fwd_kernel.hpp | 8 +- ..._attention_fwd_pipeline_default_policy.hpp | 231 ++++++++++++++---- ...hstu_attention_jagged_forward_dispatch.hpp | 2 +- ...hstu_attention_no_softmax_fwd_pipeline.hpp | 42 ++-- ...tention_no_softmax_fwd_trload_pipeline.hpp | 42 ++-- .../hstu_attention_pipeline_problem.hpp | 4 +- .../hstu_attention_tile_setting_define.hpp | 4 +- ...tu_attention_with_softmax_fwd_pipeline.hpp | 43 ++-- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 42 ++-- 10 files changed, 252 insertions(+), 168 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 84a80a2c93..771adccb44 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -59,7 +59,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch constexpr ck_tile::index_t occupancy = -1; const bool pad_seqlen_k = !(param.seqlen_kv % HstuAttentionTileSetting::kN0 == 0); - const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0); + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0); const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0); // no need to check seqlen_q since it is not used as fastest dim, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index c07e792b26..c03627ab67 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ 
b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -573,7 +573,7 @@ struct HstuAttentionFwdKernel number<1>{}); return pad_tensor_view(q_dram_naive, make_tuple(number{}, - number{}), + number{}), sequence{}); }(); const auto k_dram = [&]() { @@ -586,7 +586,7 @@ struct HstuAttentionFwdKernel return pad_tensor_view(k_dram_naive, make_tuple(number{}, - number{}), + number{}), sequence{}); }(); const auto v_dram = [&]() { @@ -624,14 +624,14 @@ struct HstuAttentionFwdKernel make_tile_window(q_dram, [&]() { return make_tuple(number{}, - number{}); + number{}); }(), {i_m0, 0}); auto k_dram_window = make_tile_window(k_dram, make_tuple(number{}, - number{}), + number{}), {0, 0}); auto v_dram_window = make_tile_window( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 3f66f9b7a9..8173892295 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -126,7 +126,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; return Problem::template GetDramTileAccessMaxVectorSize(); constexpr index_t kKVector = GetAlignmentK(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160 + if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim) + { + return kKPerBlock * kNPerBlock; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { static_assert(kKVector == kKPack); @@ -244,11 +249,20 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kMPerBlock = 
Problem::kLoadWholeQTileOnceThroughLds ? Problem::HstuAttentionTileSetting::kM0 : GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); constexpr index_t kKVector = GetAlignmentQ(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim) + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { static_assert(kKVector == kKPack); @@ -331,25 +345,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKVector = GetAlignmentQ(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); + constexpr index_t KPerThread = 
kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + // ToDo: need more consideration for hdim72 + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}); + }; } template @@ -357,25 +402,56 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKVector = GetAlignmentQ(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / 
(MThreadPerWarp * NumWarps); + if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + // ToDo: need more consideration for hdim72 + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 
3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); + + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}); + }; } template @@ -383,11 +459,36 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim) + { + constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) { 
static_assert(kKVector == kKPack); @@ -500,24 +601,52 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKVector = GetAlignmentK(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 + { + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 
3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + }; } template diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 04cd662a6a..856fbec32c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -58,7 +58,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch { constexpr ck_tile::index_t occupancy = -1; - const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0); + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kQKHeaddim == 0); const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0); // no need to check seqlen_q since it is not used as fastest dim, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index b1c1cc71c6..e3ac8b0deb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -37,7 +37,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + 
static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static_assert(Problem::kUseSoftmax == false, "This pipeline only works with not-using softmax"); @@ -53,7 +53,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -127,9 +127,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -151,8 +151,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kSubQKHeaddim == - KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == 
BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -184,7 +183,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramSingleRepMTileDistribution()); @@ -194,7 +193,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -226,7 +225,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window auto q_lds_read_window = make_tile_window(q_lds, make_tuple(number{}, number{}), @@ -241,25 +239,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); - - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = decltype(get_slice_tile( + using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_write_windows; - statically_indexed_array k_lds_read_windows; + statically_indexed_array k_lds_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = - 
get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -382,7 +370,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[i_n0], k_tiles[i_n0], partition_index); + store_tile(k_lds_windows[i_n0], k_tiles[i_n0], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -395,7 +383,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -535,7 +523,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename BiasDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 9eed457527..ab760cf860 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -55,7 +55,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool 
kPadHeadDimQK = Traits::kPadHeadDimQK; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -126,9 +126,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -150,8 +150,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kSubQKHeaddim == - KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -183,7 +182,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); @@ -193,7 +192,7 @@ 
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -219,7 +218,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window auto q_lds_read_window = make_tile_window(q_lds, make_tuple(number{}, number{}), @@ -234,25 +232,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); - - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = decltype(get_slice_tile( + using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_write_windows; - statically_indexed_array k_lds_read_windows; + statically_indexed_array k_lds_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -351,7 +339,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, n0_loops, 1>{}([&](auto i_n0) { - 
store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], k_tiles[i_n0], partition_index); @@ -366,7 +354,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -481,8 +469,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename BiasDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile HstuMask mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 886bb29c1c..3efa7faca7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -100,7 +100,7 @@ struct HstuAttentionFwdPipelineProblem CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize() { constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0; - constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim; return GetDramTileAccessMaxVectorSize(); } @@ -108,7 +108,7 @@ struct HstuAttentionFwdPipelineProblem CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize() { constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub; - constexpr index_t 
kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim; + constexpr index_t kKPerBlock = HstuAttentionTileSetting::kQKHeaddim; return GetDramTileAccessMaxVectorSize(); } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp index c61ef3912d..efb9134edd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -52,7 +52,7 @@ struct HstuAttentionFwdTileSettingClass static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen - static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size along k seqlen + static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0 static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll static constexpr index_t kQKHeaddim = @@ -60,6 +60,8 @@ struct HstuAttentionFwdTileSettingClass // once (or repeately load Q as a whole tile) static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); + + static_assert(kSubQKHeaddim % kN1 == 0, "Check failed!"); }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 2cceeb2591..b8f1b7e5be 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -37,7 +37,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; static constexpr index_t kSubQKHeaddim = 
HstuAttentionTileSetting::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static_assert(Problem::kUseSoftmax == true, "This pipeline only works with using softmax"); @@ -53,7 +53,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -127,9 +127,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -153,8 +153,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kSubQKHeaddim == - KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == 
VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -198,7 +197,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramSingleRepMTileDistribution()); @@ -208,7 +207,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -245,7 +244,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + // NOTE(review): Q tile is no longer padded to kSubQKHeaddim; read and write windows presumably both span kQKHeaddim now (TODO confirm) auto q_lds_read_window = make_tile_window(q_lds, make_tuple(number{}, number{}), @@ -260,25 +259,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); - - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = decltype(get_slice_tile( + using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_write_windows; - statically_indexed_array k_lds_read_windows; + statically_indexed_array k_lds_windows; static_for<0, 
NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -401,7 +390,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], k_tiles[number{}], partition_index); @@ -428,7 +417,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -673,7 +662,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename BiasDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 23eb04f39c..da05271e04 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -55,7 +55,7 @@ struct 
HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -126,9 +126,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -152,8 +152,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kSubQKHeaddim == - KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -197,7 +196,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto q_dram_window = 
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); @@ -207,7 +206,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -241,7 +240,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window auto q_lds_read_window = make_tile_window(q_lds, make_tuple(number{}, number{}), @@ -255,25 +253,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); - - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = decltype(get_slice_tile( + using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_write_windows; - statically_indexed_array k_lds_read_windows; + statically_indexed_array k_lds_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = - get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + 
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -372,7 +360,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, n0_loops, 1>{}([&](auto i_n0) { - store_tile(k_lds_write_windows[number{}], + store_tile(k_lds_windows[number{}], k_tiles[number{}], partition_index); @@ -399,7 +387,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -638,8 +626,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename BiasDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile HstuMask mask,