mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Format codes
This commit is contained in:
@@ -259,12 +259,14 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_lse =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
}
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
|
||||
@@ -389,11 +389,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
#if 0
|
||||
#if 0
|
||||
DEBUG_STMTS
|
||||
{
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
@@ -405,7 +404,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_scale,
|
||||
o_tile(distributed_indices));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user