diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 86566a2ed0..e0c21d26c5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -160,6 +160,19 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), @@ -177,6 +190,8 @@ struct HstuAttentionFwdPipelineQRKSVS auto q_tile = load_tile(q_dram_window); + clear_tile(o_acc); + auto k_tile = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); @@ -200,11 +215,6 @@ struct HstuAttentionFwdPipelineQRKSVS sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? - Policy::template MakeVDramTileDistribution()); // V tile in LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), @@ -222,15 +232,11 @@ struct HstuAttentionFwdPipelineQRKSVS v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); }); - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); - - SaccBlockTileType sacc_tile; - PcompBlockTileType pcomp_tile; + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); // reduction function for softmax const auto f_silu = [&](CompDataType& x) { @@ -246,13 +252,6 @@ struct HstuAttentionFwdPipelineQRKSVS } }; - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - - clear_tile(o_acc); - const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // check early exit if no work to do