diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index f54f3ac0c9..a69c0fe394 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -145,7 +145,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); @@ -199,7 +199,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -401,7 +401,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -508,7 +508,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using QKVDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr 
index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); @@ -719,7 +719,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::GemmAccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 6bbd20b3c7..f80d7df1b6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -16,12 +16,12 @@ using HstuAttentionFwdWarpTile3 = ck_tile::sequence<32, 32, 16>; template struct HstuAttentionNoSoftmaxFwdBlockTile; -// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionNoSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 32, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -29,7 +29,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<32> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -37,7 +37,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<64> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 128, 16, 128>; + using type = ck_tile::sequence<128, 32, 16, 128, 
16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -45,7 +45,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -53,12 +53,12 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<256> template struct HstuAttentionWithSoftmaxFwdBlockTile; -// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionWithSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 32, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -66,7 +66,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<32> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -74,7 +74,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 64, 128, 16, 128>; + using type = ck_tile::sequence<128, 64, 16, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -82,7 +82,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = 
ck_tile::sequence<4, 1, 1>; }; @@ -186,12 +186,12 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256> template struct HstuAttentionNoSoftmaxFwdBlockTile; -// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionNoSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 32, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -199,7 +199,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<32> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -207,7 +207,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<64> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 128, 32, 128>; + using type = ck_tile::sequence<128, 32, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -215,7 +215,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 256, 32, 256>; + using type = ck_tile::sequence<128, 32, 32, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -223,12 +223,12 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<256> template struct HstuAttentionWithSoftmaxFwdBlockTile; -// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionWithSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 32, 32>; + using type = 
ck_tile::sequence<64, 64, 32, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -236,7 +236,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<32> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -244,7 +244,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 64, 128, 32, 128>; + using type = ck_tile::sequence<128, 64, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -252,7 +252,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 64, 256, 32, 256>; + using type = ck_tile::sequence<128, 64, 32, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index d1c4ebccd4..f3c276c1a4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -31,6 +31,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub; static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; static constexpr index_t kQKHeaddim = 
HstuAttentionTileSetting::kQKHeaddim; @@ -158,17 +159,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; + static_assert(n0_loops == k1_loops, "n0_loops == k1_loops required by this pipeline"); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - // SaccBlockTile size is [kM0, kK1] + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); @@ -190,7 +194,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -204,11 +208,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS using k_tile_type = decltype(load_tile(k_dram_window)); - statically_indexed_array k_tiles; + statically_indexed_array k_tiles; - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }); __builtin_amdgcn_sched_barrier(0); @@ -238,11 +242,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, 
sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; @@ -250,11 +254,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -376,13 +381,13 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS do { // STAGE 1, Gemm_0 ( S = Q@K ) - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1], partition_index); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[i_n0], k_tiles[i_n0], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); __builtin_amdgcn_sched_barrier(0x00000001); @@ -390,7 +395,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -398,8 +403,8 @@ struct 
HstuAttentionNoSoftmaxFwdPipelineQRKSVS set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); __builtin_amdgcn_sched_barrier(0x00000001); @@ -487,7 +492,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static_for<0, k1_loops, 1>{}([&](auto i_k1) { // load k_tiles used by next iteration k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index bbceb132de..85e7829552 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -31,6 +31,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub; static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; @@ -156,17 +157,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; + static_assert(n0_loops == k1_loops, "n0_loops == k1_loops required by this pipeline"); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - // SaccBlockTile size is [kM0, kK1] + // 
SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); @@ -188,7 +192,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -196,11 +200,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - statically_indexed_array k_tiles; + statically_indexed_array k_tiles; - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); }); __builtin_amdgcn_sched_barrier(0); @@ -230,11 +234,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; @@ -242,11 +246,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static_for<0, NumKVLdsBuffers, 1>{}([&](auto 
i_buf) { k_lds_write_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -344,15 +349,15 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad do { // STAGE 1, Gemm_0 ( S = Q@K ) - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - k_tiles[i_k1], + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + k_tiles[i_n0], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); + v_tiles[i_n0] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); __builtin_amdgcn_sched_barrier(0x00000001); @@ -360,7 +365,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -368,8 +373,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); __builtin_amdgcn_sched_barrier(0x00000001); @@ -449,7 +454,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // load k_tiles used by next iteration k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); __builtin_amdgcn_sched_barrier(0x00000001); diff --git 
a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp index d19dfc91c2..c61ef3912d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -36,7 +36,7 @@ struct HstuAttentionFwdTileSettingClass using Gemm1BlockWarps = remove_cvref_t; using Gemm1WarpTile = remove_cvref_t; - static_assert(BlockTile::size() == 5, "Check failed!"); + static_assert(BlockTile::size() == 6, "Check failed!"); static_assert(Gemm0BlockWarps::size() == 3, "Check failed!"); static_assert(Gemm0WarpTile::size() == 3, "Check failed!"); static_assert(Gemm1BlockWarps::size() == 3, "Check failed!"); @@ -50,12 +50,13 @@ struct HstuAttentionFwdTileSettingClass static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen - static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen - static constexpr index_t kN1 = BlockTile::at(number<2>{}); // tile size along v head_dim - static constexpr index_t kK1 = BlockTile::at(number<3>{}); // tile size along kv gemm unroll + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // sub-tile size along k seqlen for each gemm_0 unroll + static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll static constexpr index_t kQKHeaddim = - BlockTile::at(number<4>{}); // total length of K0, used for pipeline that need load Q at + BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) 
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 7858e28d04..538ec053cf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -31,6 +31,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub; static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; @@ -160,8 +161,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; + static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline"); static_assert(k1_loops >= 2, "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); @@ -171,9 +174,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - // SaccBlockTile size is [kM0, kK1] + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); @@ -205,7 +208,7 @@ struct 
HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -228,7 +231,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }); __builtin_amdgcn_sched_barrier(0); @@ -258,11 +261,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; @@ -270,11 +273,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -396,23 +400,27 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS do { // STAGE 1, Gemm_0 ( S = Q@K ) - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - 
store_tile(k_lds_write_windows[number{}], - k_tiles[number{}], + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + k_tiles[number{}], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(i_k1 < k1_loops - NumPrefetchK) + if constexpr(i_n0 < n0_loops - NumPrefetchK) { - k_tiles[number{}] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); } else { - // load v_tiles used in current iteration - v_tiles[number{}] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops) + { + // load v_tiles used in current iteration + v_tiles[number{}] = + load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + } }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -420,7 +428,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -428,8 +436,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); // STAGE 2, scale_s, add bias, mask, siLU @@ -511,7 +519,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); @@ -600,7 +608,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { // load k_tiles used by next iteration k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + 
move_tile_window(k_dram_window, {kN0Sub, 0}); }; __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index bf2d2173e8..d8e96289bd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -31,6 +31,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN0Sub = HstuAttentionTileSetting::kN0Sub; static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; @@ -158,8 +159,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; + static_assert(n0_loops >= k1_loops, "n0_loops >= k1_loops required by this pipeline"); static_assert(k1_loops >= 2, "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); @@ -169,9 +172,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - // SaccBlockTile size is [kM0, kK1] + // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); @@ 
-203,7 +206,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); @@ -223,7 +226,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }); __builtin_amdgcn_sched_barrier(0); @@ -252,11 +255,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - using k_lds_read_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; @@ -264,11 +267,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = + get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); // V tile in LDS @@ -366,23 +370,27 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad do { // STAGE 1, Gemm_0 ( S = Q@K ) - static_for<0, 
k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - k_tiles[number{}], + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_write_windows[number{}], + k_tiles[number{}], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); - if constexpr(i_k1 < k1_loops - NumPrefetchK) + if constexpr(i_n0 < n0_loops - NumPrefetchK) { - k_tiles[number{}] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); } else { - // load v_tiles used in current iteration - v_tiles[number{}] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); + if constexpr(i_n0 - (n0_loops - NumPrefetchK) < k1_loops) + { + // load v_tiles used in current iteration + v_tiles[number{}] = + load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + } }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -390,7 +398,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -398,8 +406,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad set_slice_tile(pcomp_tile, tmp_tile, - sequence<0, i_k1 * kK1>{}, - sequence{}); + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); __builtin_amdgcn_sched_barrier(0x00000001); @@ -477,7 +485,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0}); @@ -572,7 +580,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { // load k_tiles used by next iteration k_tiles[i_k1] = load_tile(k_dram_window); - 
move_tile_window(k_dram_window, {kK1, 0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); }; __builtin_amdgcn_sched_barrier(0x00000001);