mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Fix bug of elementwise op, our elementwise op is not inout
This commit is contained in:
@@ -325,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
@@ -468,7 +468,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
tile_elementwise_inout(p_compute_element_func, p_compute);
|
||||
tile_elementwise_in(p_compute_element_func, p_compute);
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
@@ -555,7 +555,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -384,7 +384,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
@@ -650,7 +650,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -332,7 +332,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
#endif
|
||||
@@ -546,7 +546,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user