mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Re-order pipeline call operator arguments
This commit is contained in:
@@ -421,9 +421,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
smem_ptr,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
kargs.max_seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -431,9 +431,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
kargs.num_splits,
|
||||
smem_ptr,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
kargs.max_seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -83,9 +83,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
void* smem_ptr,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
index_t seqlen_q,
|
||||
index_t max_seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
LSEDataType* lse_acc_lds_ptr =
|
||||
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
@@ -116,7 +116,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
if(row < num_splits && col < real_seqlen_q)
|
||||
if(row < num_splits && col < seqlen_q)
|
||||
{
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
|
||||
}
|
||||
@@ -312,9 +312,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
index_t num_splits,
|
||||
void* smem_ptr,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
index_t seqlen_q,
|
||||
index_t max_seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
o_acc_dram_block_window,
|
||||
@@ -322,9 +322,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
smem_ptr,
|
||||
real_seqlen_q,
|
||||
max_seqlen_q);
|
||||
seqlen_q,
|
||||
max_seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user