Fix bug of elementwise op, our elementwise op is not inout

This commit is contained in:
rocking
2024-04-04 03:17:36 +00:00
parent bfcf550305
commit d9323ea261
3 changed files with 10 additions and 10 deletions

View File

@@ -325,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
@@ -468,7 +468,7 @@ struct BlockFmhaPipelineQRKSVS
}
move_tile_window(v_dram_window, {0, kK1});
tile_elementwise_inout(p_compute_element_func, p_compute);
tile_elementwise_in(p_compute_element_func, p_compute);
const auto p = cast_tile<PDataType>(p_compute);
// STAGE 3, KV gemm
@@ -555,7 +555,7 @@ struct BlockFmhaPipelineQRKSVS
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}

View File

@@ -368,7 +368,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -384,7 +384,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
@@ -650,7 +650,7 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}

View File

@@ -316,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -332,7 +332,7 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
tile_elementwise_inout(s_acc_element_func, s_acc);
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
#endif
@@ -546,7 +546,7 @@ struct BlockFmhaPipelineQSKSVS
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}