mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Update and fix for leaked changes, and enable the scripts to test/benchmark kStoreLSE cases
This commit is contained in:
@@ -415,6 +415,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
bool store_lse = (is_training & use_softmax);
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
@@ -424,9 +426,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
ck_tile::HostTensor<CompDataType> lse_host_ref(
|
||||
(is_training && use_softmax)
|
||||
? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
store_lse ? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(
|
||||
save_mask
|
||||
@@ -593,7 +594,6 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
|
||||
BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
|
||||
bool store_lse = (is_training && use_softmax);
|
||||
ck_tile::reference_no_group_hstu_attention_fwd<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
@@ -640,7 +640,7 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
|
||||
res = ck_tile::check_err(
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
|
||||
|
||||
if(is_training && use_softmax)
|
||||
if(store_lse)
|
||||
{
|
||||
ck_tile::HostTensor<CompDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head});
|
||||
@@ -913,6 +913,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
|
||||
int batches_for_alloc = 1;
|
||||
|
||||
bool store_lse = (is_training & use_softmax);
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
@@ -922,9 +924,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
|
||||
ck_tile::HostTensor<CompDataType> lse_host_ref(
|
||||
(is_training && use_softmax)
|
||||
? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
store_lse ? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
ck_tile::HostTensor<int8_t> mask_host(save_mask
|
||||
? std::array<ck_tile::index_t, 4>{num_batch,
|
||||
@@ -1054,7 +1055,6 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
|
||||
BOOL_SWITCH(use_causal, kUseCausal, [&] {
|
||||
bool store_lse = (is_training && use_softmax);
|
||||
ck_tile::reference_group_hstu_attention_fwd<
|
||||
InOutDataType,
|
||||
GemmAccDataType,
|
||||
@@ -1103,7 +1103,7 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
res = ck_tile::check_err(
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
|
||||
|
||||
if(is_training && use_softmax)
|
||||
if(store_lse)
|
||||
{
|
||||
ck_tile::HostTensor<CompDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head});
|
||||
|
||||
@@ -79,7 +79,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
ODataType,
|
||||
false /* kIsJagged */,
|
||||
kUseSoftmax,
|
||||
false, // kStoreLSE
|
||||
kStoreLSE,
|
||||
HstuAttentionCombineTileSetting,
|
||||
kMaxSplits>;
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
@@ -299,6 +300,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
{
|
||||
batch_offset_lse_acc = query_start * kargs.num_head * kargs.num_splits;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start * kargs.seq_stride_lse;
|
||||
}
|
||||
|
||||
kargs.seqlen_q =
|
||||
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
|
||||
}
|
||||
@@ -317,6 +323,10 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.seqlen_q *
|
||||
kargs.num_head * kargs.num_splits;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
}
|
||||
|
||||
index_t i_m0;
|
||||
@@ -394,8 +404,40 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
number<HstuAttentionPipeline::kMaxSplits>{}),
|
||||
{i_m0, 0});
|
||||
|
||||
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto lse_dram_window_lengths =
|
||||
make_tuple(number<HstuAttentionPipeline::kM>{});
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
LSEDataType* lse_ptr =
|
||||
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
|
||||
batch_offset_lse;
|
||||
|
||||
const auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_ptr,
|
||||
make_tuple(kargs.seqlen_q),
|
||||
make_tuple(kargs.seq_stride_lse),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(lse_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
return HstuAttentionPipeline{}(lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
kargs.hdim_v,
|
||||
kargs.num_splits,
|
||||
smem_ptr);
|
||||
|
||||
@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
@@ -134,7 +134,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
@@ -204,7 +204,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
auto lse_or_lse_acc =
|
||||
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
|
||||
@@ -600,7 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// if pipeline is called from splitkv_kernel, the window shall not be null;
|
||||
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
// store lse_or_lse_acc
|
||||
auto lse_or_lse_acc =
|
||||
@@ -641,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
|
||||
@@ -63,13 +63,17 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
|
||||
template <typename LSEaccDramBlockWindowTmp,
|
||||
typename OAccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindow,
|
||||
typename OAccElementFunction,
|
||||
typename LSEaccElementFunction>
|
||||
typename LSEaccElementFunction,
|
||||
typename LSEElementFunction>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile
|
||||
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
|
||||
LSEDramBlockWindow& lse_dram_block_window, // kM tile
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
const LSEaccElementFunction& lse_acc_element_func,
|
||||
const LSEElementFunction& lse_element_func,
|
||||
index_t o_acc_split_stride,
|
||||
index_t num_splits,
|
||||
void* smem_ptr) const
|
||||
@@ -166,6 +170,12 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
// in case kStoreLSE is false, LSEDramBlockWindow is null
|
||||
if constexpr(!is_null_tile_window_v<LSEDramBlockWindow>)
|
||||
{
|
||||
store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
// calculate scale value (used for adjusting the o_acc) for all splits for all rows in
|
||||
// the tile
|
||||
lse_acc_type& lse_scale = lse_acc;
|
||||
@@ -235,16 +245,21 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename LSEaccDramBlockWindow, typename OAccDramBlockWindowTmp>
|
||||
template <typename LSEaccDramBlockWindow,
|
||||
typename OAccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindow>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp,
|
||||
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
|
||||
LSEDramBlockWindow& lse_dram_block_window, // kM tile
|
||||
ck_tile::index_t o_acc_split_stride,
|
||||
index_t num_splits,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window_tmp,
|
||||
o_acc_dram_block_window_tmp,
|
||||
lse_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
o_acc_split_stride,
|
||||
|
||||
@@ -294,7 +294,11 @@ struct reference_no_group_hstu_attention_fwd
|
||||
|
||||
if(store_lse)
|
||||
{
|
||||
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
|
||||
if constexpr(kIsJagged)
|
||||
lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) =
|
||||
std::log(l) + m;
|
||||
else
|
||||
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -583,7 +587,8 @@ struct reference_group_hstu_attention_fwd
|
||||
|
||||
if(store_lse)
|
||||
{
|
||||
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
|
||||
lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) =
|
||||
std::log(l) + m;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
fi
|
||||
|
||||
dtype="bf16"
|
||||
|
||||
@@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
fi
|
||||
|
||||
dtype="bf16"
|
||||
|
||||
@@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -v=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -training=$Training"
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
|
||||
ndist=0
|
||||
|
||||
|
||||
@@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
fi
|
||||
|
||||
for T in "fp16" "bf16"; do
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
|
||||
attn_scale=0
|
||||
if [ $# -ge 1 ]; then
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
## This script can be used for verifying the use of WarpGemm 32x32x16, which is used by hdim64 + softmax
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
|
||||
attn_scale=1.0
|
||||
ndist=1
|
||||
|
||||
@@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
fi
|
||||
|
||||
dtype="bf16"
|
||||
|
||||
@@ -8,7 +8,10 @@ fi
|
||||
|
||||
set +x
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_hstu_attention
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
|
||||
dtype="bf16"
|
||||
|
||||
|
||||
@@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then
|
||||
USE_SOFTMAX=$1
|
||||
fi
|
||||
|
||||
Training=${TEST_HSTU_FWD_TRAINING:-0}
|
||||
|
||||
if [ $USE_SOFTMAX -eq 1 ]; then
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
|
||||
else
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention"
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
|
||||
fi
|
||||
|
||||
dtype="bf16"
|
||||
|
||||
Reference in New Issue
Block a user