From 7234b2fc1a9a791121b1182334ca89da68374a75 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 1 Dec 2025 14:34:53 +0000 Subject: [PATCH] Simplify the code with regard to k_lds_write_windows and k_lds_read_windows in the pipelines --- .../hstu_attention_no_softmax_fwd_pipeline.hpp | 17 +++++++---------- ...attention_no_softmax_fwd_trload_pipeline.hpp | 17 +++++++---------- ...hstu_attention_with_softmax_fwd_pipeline.hpp | 17 +++++++---------- ...tention_with_softmax_fwd_trload_pipeline.hpp | 17 +++++++---------- 4 files changed, 28 insertions(+), 40 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 51dddcde5b..dc17803a35 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -233,28 +233,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_write_window = make_tile_window( + auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - auto k_lds_read_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_write_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); - using k_lds_read_window_type = decltype(get_slice_tile( - k_lds_read_window, sequence<0, 0>{}, sequence{})); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + using k_lds_read_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array 
k_lds_write_windows; statically_indexed_array k_lds_read_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_write_window, + get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 9547ec21de..c4ffd580a9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -226,28 +226,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_write_window = make_tile_window( + auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - auto k_lds_read_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_write_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); - using k_lds_read_window_type = decltype(get_slice_tile( - k_lds_read_window, sequence<0, 0>{}, sequence{})); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + using k_lds_read_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; static_for<0, NumKVLdsBuffers, 
1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_write_window, + get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index e8fa211222..85aff422a1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -253,28 +253,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_write_window = make_tile_window( + auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - auto k_lds_read_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_write_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); - using k_lds_read_window_type = decltype(get_slice_tile( - k_lds_read_window, sequence<0, 0>{}, sequence{})); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + using k_lds_read_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_write_window, + get_slice_tile(k_lds_window, 
sequence{}, sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 9af1a1b610..c089c915c9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -248,28 +248,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_write_window = make_tile_window( + auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window - auto k_lds_read_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); - using k_lds_write_window_type = decltype(get_slice_tile( - k_lds_write_window, sequence<0, 0>{}, sequence{})); + k_lds_window, sequence<0, 0>{}, sequence{})); - using k_lds_read_window_type = decltype(get_slice_tile( - k_lds_read_window, sequence<0, 0>{}, sequence{})); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + using k_lds_read_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); statically_indexed_array k_lds_write_windows; statically_indexed_array k_lds_read_windows; static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { k_lds_write_windows[i_buf] = - get_slice_tile(k_lds_write_window, + get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); - k_lds_read_windows[i_buf] = 
get_slice_tile(k_lds_read_window, + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); });