Adjust P elementwise function

This commit is contained in:
rocking
2024-04-01 13:36:22 -04:00
committed by rocking
parent cf57626c07
commit bfcf550305
3 changed files with 6 additions and 6 deletions

View File

@@ -325,8 +325,8 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
@@ -341,9 +341,9 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});

View File

@@ -368,8 +368,8 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
@@ -384,9 +384,9 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});

View File

@@ -316,8 +316,8 @@ struct BlockFmhaPipelineQSKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
@@ -332,9 +332,9 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});