mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Change output accum tensor layout of fmha fwd split-kv & combine kernels (#1527)
* Use same layout for o_acc and o tensor * Use better param names in partitioner * Remove redundant kargs 'max_seqlen_q' * Use better param names in splitkv kernel * Add comment for additional kernel arguments * Sync empty loop early return logics between pipelines * Pass more arguments to cmake in scripts * Align backslashes * Fix wrong o_acc tensor view strides * Change o_acc layout if o_perm=0 * Handle whole row masked via attn_bias * Use use vector width = 1 for o_acc * Use more even split sizes
This commit is contained in:
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
// lse_acc tile in LDS
|
||||
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t original_num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user