Set invalid element value for LSEacc tensor view

This commit is contained in:
PoYen, Chen
2024-06-12 02:53:55 +00:00
parent ff866f6bb6
commit a939ec5da4

View File

@@ -333,6 +333,7 @@ struct FmhaFwdSplitKVCombineKernel
lse_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q),
make_tuple(kargs.split_stride_lse_acc, 1),
-numeric<LSEDataType>::infinity(),
number<8>{},
number<1>{});
@@ -421,7 +422,6 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.seqlen_q,
kargs.max_seqlen_q,
smem_ptr);
}
@@ -431,7 +431,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.seqlen_q,
kargs.max_seqlen_q,
smem_ptr);
}