Simplifying the codes with regard to k_lds_wite_windows and k_lds_read_windows in the pipelines

This commit is contained in:
Qianfeng Zhang
2025-12-01 14:34:53 +00:00
parent c1817464be
commit 7234b2fc1a
4 changed files with 28 additions and 40 deletions

View File

@@ -233,28 +233,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});

View File

@@ -226,28 +226,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});

View File

@@ -253,28 +253,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});

View File

@@ -248,28 +248,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});