mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Adjust P elementwise function
This commit is contained in:
@@ -325,8 +325,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -341,9 +341,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
@@ -368,8 +368,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -384,9 +384,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
@@ -316,8 +316,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -332,9 +332,9 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
Reference in New Issue
Block a user