mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Add support for preparing lse_dram_window in hstu fwd kernel
This commit is contained in:
@@ -41,6 +41,8 @@ struct HstuAttentionFwdKernel
|
||||
using BiasDataType =
|
||||
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::BiasDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::ODataType>;
|
||||
using CompDataType =
|
||||
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::CompDataType>;
|
||||
|
||||
static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention;
|
||||
static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup;
|
||||
@@ -701,6 +703,7 @@ struct HstuAttentionFwdKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
@@ -717,6 +720,10 @@ struct HstuAttentionFwdKernel
|
||||
batch_offset_bias = query_start * kargs.seq_stride_bias;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.seq_stride_o;
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start * kargs.seq_stride_lse;
|
||||
}
|
||||
|
||||
kargs.seqlen_q =
|
||||
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
|
||||
@@ -747,6 +754,10 @@ struct HstuAttentionFwdKernel
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
}
|
||||
|
||||
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
|
||||
@@ -916,6 +927,35 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto lse_dram_window_lengths =
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{});
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
CompDataType* lse_ptr =
|
||||
reinterpret_cast<CompDataType*>(kargs.lse_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
|
||||
|
||||
const auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_ptr,
|
||||
make_tuple(seqlen_q_in_ctrl),
|
||||
make_tuple(kargs.seq_stride_lse),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(lse_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -985,13 +1025,11 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple());
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_tile_window,
|
||||
lse_dram_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1040,12 +1078,11 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple());
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_tile_window,
|
||||
lse_dram_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
|
||||
@@ -43,14 +43,14 @@ struct HstuAttentionNoGroupFwdParams
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q]
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead]
|
||||
|
||||
// batched mode only parameters
|
||||
ck_tile::index_t batch_stride_q;
|
||||
@@ -106,14 +106,14 @@ struct HstuAttentionGroupFwdParams
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q]
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead]
|
||||
|
||||
const void* num_targets_ptr;
|
||||
|
||||
|
||||
@@ -81,7 +81,6 @@ struct HstuAttentionFwdPipelineProblem
|
||||
using CompDataType = remove_cvref_t<CompDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
|
||||
// to be compatible with ck_tile existing policy codes
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using PDataType = QKVDataType;
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
@@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
|
||||
const LSEaccElementFunction& lse_acc_element_func,
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
@@ -204,15 +204,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
{
|
||||
auto lse_acc =
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<CompDataType>::infinity());
|
||||
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
|
||||
|
||||
store_tile(lse_acc_dram_block_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
}
|
||||
|
||||
return o_acc;
|
||||
@@ -598,19 +598,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
|
||||
// if pipeline is called from splitkv_kernel, the window shall not be null;
|
||||
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
{
|
||||
// store lse acc
|
||||
auto lse_acc = make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
// store lse_or_lse_acc
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_acc_dram_block_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
}
|
||||
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -638,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
@@ -660,7 +663,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_acc_dram_block_window_tmp,
|
||||
lse_or_lse_acc_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
@@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
|
||||
const LSEaccElementFunction& lse_acc_element_func,
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
@@ -205,15 +205,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
{
|
||||
auto lse_acc =
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<CompDataType>::infinity());
|
||||
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
|
||||
|
||||
store_tile(lse_acc_dram_block_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
}
|
||||
|
||||
return o_acc;
|
||||
@@ -604,19 +604,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
|
||||
// if pipeline is called from splitkv_kernel, the window shall not be null;
|
||||
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
{
|
||||
// store lse acc
|
||||
auto lse_acc = make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
// store lse or lse_acc
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_acc_dram_block_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
}
|
||||
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -644,15 +647,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp&
|
||||
lse_acc_dram_block_window_tmp, // M0 tile //
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
@@ -667,7 +669,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_acc_dram_block_window_tmp,
|
||||
lse_or_lse_acc_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user