diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index b97fd9d550..86b5da9af2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -208,9 +208,10 @@ struct HstuAttentionFwdPipelineQRKSVS q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_write_window = make_tile_window( q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window auto q_lds_read_window = make_tile_window(q_lds, - Policy::template MakeQLdsBlockDescriptor().get_lengths(), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeQRegSingleRepMTileDistribution()); @@ -218,18 +219,30 @@ struct HstuAttentionFwdPipelineQRKSVS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = make_tile_window( + auto k_lds_write_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto k_lds_read_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - statically_indexed_array k_lds_windows; + using k_lds_write_window_type = decltype(get_slice_tile( + k_lds_write_window, sequence<0, 0>{}, sequence{})); + + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_read_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_write_windows; + statically_indexed_array k_lds_read_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_windows[i_buf] = get_slice_tile(k_lds_window, - sequence{}, - sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + k_lds_write_windows[i_buf] = + get_slice_tile(k_lds_write_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); // V tile in LDS @@ -356,7 +369,7 @@ struct HstuAttentionFwdPipelineQRKSVS v_tile_type v_tile; - store_tile(k_lds_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile)); + store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile)); do { @@ -380,7 +393,7 @@ struct HstuAttentionFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); if constexpr(V_VMEM_LOAD_INST + K_VMEM_LOAD_INST < K_LDS_READ_INST) { @@ -448,7 +461,7 @@ struct HstuAttentionFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); }; sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); @@ -533,7 +546,7 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(k_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + store_tile(k_lds_write_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], tile_elementwise_in(k_element_func, k_tile)); __builtin_amdgcn_sched_barrier(0x00000001); @@ -547,7 +560,7 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(k_lds_windows[number<0>{}], + store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile)); } }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index d1169f7973..8bfb1fa8b3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -103,7 +103,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; @@ -127,7 +127,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; @@ -170,7 +170,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -212,7 +212,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); constexpr index_t kKVector = GetAlignmentQ(); @@ -301,7 +301,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); @@ -331,7 +331,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -438,7 +438,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;