mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add support of softmax in hstu attention
This commit is contained in:
@@ -48,6 +48,7 @@
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("softmax", "0", "use softmax or not")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
|
||||
@@ -104,6 +104,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
|
||||
.insert("targets", "", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("max_target", "0", "max target, can be ignored, or else must be equal of bigger than the maximum of all targets")
|
||||
.insert("softmax", "0", "use softmax or not")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
@@ -216,6 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int num_head = arg_parser.get_int("nhead");
|
||||
int hdim_qk = arg_parser.get_int("hdim_qk");
|
||||
int hdim_v = arg_parser.get_int("hdim_v");
|
||||
bool use_softmax = static_cast<bool>(arg_parser.get_int("softmax"));
|
||||
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
|
||||
|
||||
int window_size = arg_parser.get_int("local_len");
|
||||
@@ -434,6 +436,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
params.window_size = window_size;
|
||||
params.contextual_seqlen = contextual_seqlen;
|
||||
@@ -473,6 +476,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.batch_stride_bias = 0;
|
||||
params.batch_stride_o = o_host_ref.get_strides()[0];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
params.window_size = window_size;
|
||||
params.contextual_seqlen = contextual_seqlen;
|
||||
@@ -513,11 +517,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_softmax, kUseSoftmax, use_causal, kUseCausal, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
|
||||
@@ -17,12 +17,14 @@ void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStrea
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false, // using softmax
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -17,12 +17,14 @@ void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStrea
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false, // using softmax
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_batched_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -245,7 +245,7 @@ struct HstuAttentionFwdKernel
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
-scale_s,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(seqlen), // max_seqlen
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
@@ -320,7 +320,7 @@ struct HstuAttentionFwdKernel
|
||||
hdim_v,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
num_head,
|
||||
-scale_s,
|
||||
scale_s,
|
||||
attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen),
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
|
||||
@@ -156,6 +156,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr bool kUseSoftmax = Problem::kUseSoftmax;
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
|
||||
@@ -173,6 +175,16 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
|
||||
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
|
||||
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
OaccBlockTileType o_acc;
|
||||
|
||||
@@ -279,15 +291,26 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
|
||||
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / (neg_one - exp(x));
|
||||
x = x / (one + exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
{
|
||||
return __expf(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -353,6 +376,12 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
};
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
@@ -409,41 +438,141 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s - type_convert<CompDataType>(bias_element_func(y));
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kK1});
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
|
||||
}
|
||||
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
pcomp_tile(i_j_idx) = static_cast<CompDataType>(0.0f);
|
||||
if(!mask.IsTokenPairInsideMask(row, col))
|
||||
{
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
};
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x) { x = x * type_convert<CompDataType>(scale_p); }, pcomp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!mask.IsFullTileInsideMask(
|
||||
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(col >= seqlen_k_end)
|
||||
{
|
||||
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
// adjust o_acc[] according to the update between m and m_old
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
{
|
||||
l(i_idx) = rowsum_p[i_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
@@ -502,6 +631,23 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if(m[i_idx] == -numeric<CompDataType>::infinity())
|
||||
o_acc(i_j_idx) = 0.0f;
|
||||
else
|
||||
o_acc(i_j_idx) *= 1.0f / l[i_idx];
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
|
||||
@@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false, // using softmax
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -17,12 +17,14 @@ void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false, // using softmax
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
BOOL_SWITCH(param.use_softmax, kUseSoftmax, [&] {
|
||||
run_jagged_forward_causal_softmax_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -52,6 +52,8 @@ struct HstuAttentionFwdParams
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
|
||||
bool use_softmax;
|
||||
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
|
||||
@@ -29,6 +29,7 @@ template <typename InOutDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CompDataType,
|
||||
bool kIsJagged,
|
||||
bool kUseSoftmax,
|
||||
bool kUseCausal>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
@@ -150,6 +151,11 @@ struct reference_hstu_attention
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < seqlen; sq++)
|
||||
{
|
||||
CompDataType m =
|
||||
-ck_tile::numeric<CompDataType>::infinity(); // max value of the row
|
||||
CompDataType l =
|
||||
ck_tile::type_convert<CompDataType>(0.0f); // sum of exp(x-m) of the row
|
||||
//
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
@@ -186,12 +192,41 @@ struct reference_hstu_attention
|
||||
ck_tile::type_convert<CompDataType>(alpha));
|
||||
}
|
||||
else
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
|
||||
{
|
||||
if constexpr(!kUseSoftmax)
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
|
||||
else
|
||||
locals.push_back(-ck_tile::numeric<CompDataType>::infinity());
|
||||
};
|
||||
};
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem) * ck_tile::type_convert<CompDataType>(scale_p);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(CompDataType elem : locals)
|
||||
m = ck_tile::max(m, elem);
|
||||
|
||||
if(m == -ck_tile::numeric<CompDataType>::infinity())
|
||||
{
|
||||
for(CompDataType& elem : locals)
|
||||
elem = ck_tile::type_convert<CompDataType>(0.0f);
|
||||
}
|
||||
else
|
||||
{
|
||||
// stabalized sum of exp()
|
||||
for(CompDataType elem : locals)
|
||||
l += std::exp(elem - m);
|
||||
|
||||
// normalization
|
||||
for(CompDataType& elem : locals)
|
||||
elem = std::exp(elem - m) / l;
|
||||
}
|
||||
};
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
|
||||
attn_scale=0
|
||||
if [ $# -ge 1 ]; then
|
||||
attn_scale=$1
|
||||
fi
|
||||
|
||||
ndist=0
|
||||
|
||||
if [ $# -ge 2 ]; then
|
||||
ndist=$2
|
||||
fi
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
set +x
|
||||
done
|
||||
Reference in New Issue
Block a user