Re-order pipeline call operator arguments

This commit is contained in:
PoYen, Chen
2024-06-11 19:54:30 +00:00
parent 9d1243e7fa
commit ec82f3bbd6
2 changed files with 14 additions and 14 deletions

View File

@@ -421,9 +421,9 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr,
kargs.seqlen_q,
kargs.max_seqlen_q);
kargs.max_seqlen_q,
smem_ptr);
}
else
{
@@ -431,9 +431,9 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
smem_ptr,
kargs.seqlen_q,
kargs.max_seqlen_q);
kargs.max_seqlen_q,
smem_ptr);
}
}();

View File

@@ -83,9 +83,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
void* smem_ptr,
index_t real_seqlen_q,
index_t max_seqlen_q) const
index_t seqlen_q,
index_t max_seqlen_q,
void* smem_ptr) const
{
LSEDataType* lse_acc_lds_ptr =
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
@@ -116,7 +116,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
if(row < num_splits && col < real_seqlen_q)
if(row < num_splits && col < seqlen_q)
{
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
}
@@ -312,9 +312,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
void* smem_ptr,
index_t real_seqlen_q,
index_t max_seqlen_q) const
index_t seqlen_q,
index_t max_seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
o_acc_dram_block_window,
@@ -322,9 +322,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
smem_ptr,
real_seqlen_q,
max_seqlen_q);
seqlen_q,
max_seqlen_q,
smem_ptr);
}
};