This commit is contained in:
Juuso Korhonen
2025-11-24 10:20:04 +00:00
parent f552cd7841
commit f2fbc44b7b
3 changed files with 23 additions and 21 deletions

View File

@@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
-1,
0,
seqlen_q,
seqlen_kv,
true));
// ck_tile::reference_batched_masking(
// s_host_ref,
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
// -1,
// 0,
// seqlen_q,
// seqlen_kv,
// true));
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
@@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);

View File

@@ -310,18 +310,18 @@ struct UnifiedAttentionKernel
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
// const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len =
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
+ 1));
// index_t _max_seq_prefix_len =
// amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
// + 1));
if(seq_len < _max_seq_prefix_len)
{
_max_seq_prefix_len = seq_len;
}
// if(seq_len < _max_seq_prefix_len)
// {
// _max_seq_prefix_len = seq_len;
// }
const auto max_seq_prefix_len = _max_seq_prefix_len;
const auto max_seq_prefix_len = seq_len; // _max_seq_prefix_len;
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
// TODO sliding window

View File

@@ -897,10 +897,11 @@ struct UnifiedAttentionPipeline
auto fmha_mask = [&](auto sp_reg_idx) {
if constexpr(FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
i_total_loops * BLOCK_SIZE,
number<BLOCK_M>{},
number<BLOCK_SIZE>{});
bool need_perpixel_check = false;
// mask.IsEdgeTile(q_origin.at(number<0>{}),
// i_total_loops * BLOCK_SIZE,
// number<BLOCK_M>{},
// number<BLOCK_SIZE>{});
if(need_perpixel_check)
{
set_tile_if(sp(sp_reg_idx).sp_compute,