mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
fix
This commit is contained in:
@@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
-1,
|
||||
0,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
true));
|
||||
// ck_tile::reference_batched_masking(
|
||||
// s_host_ref,
|
||||
// ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
// -1,
|
||||
// 0,
|
||||
// seqlen_q,
|
||||
// seqlen_kv,
|
||||
// true));
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
@@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
|
||||
@@ -310,18 +310,18 @@ struct UnifiedAttentionKernel
|
||||
const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q);
|
||||
const index_t seq_len = kargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
// const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len =
|
||||
amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
+ 1));
|
||||
// index_t _max_seq_prefix_len =
|
||||
// amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1)
|
||||
// + 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
_max_seq_prefix_len = seq_len;
|
||||
}
|
||||
// if(seq_len < _max_seq_prefix_len)
|
||||
// {
|
||||
// _max_seq_prefix_len = seq_len;
|
||||
// }
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const auto max_seq_prefix_len = seq_len; // _max_seq_prefix_len;
|
||||
const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
||||
|
||||
// TODO sliding window
|
||||
|
||||
@@ -897,10 +897,11 @@ struct UnifiedAttentionPipeline
|
||||
auto fmha_mask = [&](auto sp_reg_idx) {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
i_total_loops * BLOCK_SIZE,
|
||||
number<BLOCK_M>{},
|
||||
number<BLOCK_SIZE>{});
|
||||
bool need_perpixel_check = false;
|
||||
// mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
// i_total_loops * BLOCK_SIZE,
|
||||
// number<BLOCK_M>{},
|
||||
// number<BLOCK_SIZE>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(sp(sp_reg_idx).sp_compute,
|
||||
|
||||
Reference in New Issue
Block a user