Format codes

This commit is contained in:
PoYen, Chen
2024-06-06 04:36:49 +00:00
parent 18a7223b96
commit ffd2768000
2 changed files with 6 additions and 5 deletions

View File

@@ -259,12 +259,14 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_lse =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
}
batch_offset_o = query_start * kargs.row_stride_o;

View File

@@ -389,11 +389,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
const auto row = x_indices.at(number<0>{});
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
#if 0
#if 0
DEBUG_STMTS
{
const auto col = x_indices.at(number<1>{});
@@ -405,7 +404,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_scale,
o_tile(distributed_indices));
}
#endif
#endif
});
});
}