mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Simplifying the codes with regard to k_lds_wite_windows and k_lds_read_windows in the pipelines
This commit is contained in:
@@ -233,28 +233,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
@@ -226,28 +226,25 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
@@ -253,28 +253,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
@@ -248,28 +248,25 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
|
||||
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_write_window,
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
|
||||
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kK1, 0>{},
|
||||
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user