mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline
This commit is contained in:
@@ -111,7 +111,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction>
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -123,6 +126,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
@@ -319,13 +325,15 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
@@ -335,6 +343,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -459,6 +468,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
tile_elementwise_inout(p_compute_element_func, p_compute);
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
@@ -545,6 +555,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
@@ -573,6 +585,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale,
|
||||
smem_ptr);
|
||||
|
||||
@@ -122,7 +122,10 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction>
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -134,6 +137,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
@@ -362,13 +368,15 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
@@ -378,6 +386,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -521,6 +530,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(p_compute_element_func, p_compute);
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
@@ -640,6 +650,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
@@ -668,6 +680,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale,
|
||||
smem_ptr);
|
||||
|
||||
@@ -97,7 +97,10 @@ struct BlockFmhaPipelineQSKSVS
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction>
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -109,6 +112,9 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale,
|
||||
void* smem_ptr) const
|
||||
@@ -310,13 +316,15 @@ struct BlockFmhaPipelineQSKSVS
|
||||
// STAGE 2, scale, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x = scale * x + log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
@@ -326,6 +334,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
|
||||
tile_elementwise_inout(s_acc_element_func, s_acc);
|
||||
#endif
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
@@ -450,6 +459,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
tile_elementwise_inout(p_compute_element_func, p_compute);
|
||||
const auto p = cast_tile<PDataType>(p_compute);
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
@@ -536,6 +546,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
});
|
||||
});
|
||||
|
||||
tile_elementwise_inout(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user