Add SAccElementFunction, PComputeElementFunction, and OAccElementFunction to the FMHA pipelines

This commit is contained in:
rocking
2024-03-29 02:55:16 -04:00
parent b0b8a5ad46
commit 50c36f352a
3 changed files with 54 additions and 12 deletions

View File

@@ -111,7 +111,10 @@ struct BlockFmhaPipelineQRKSVS
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction>
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -123,6 +126,9 @@ struct BlockFmhaPipelineQRKSVS
const BiasElementFunction& bias_element_func,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale,
void* smem_ptr) const
@@ -319,13 +325,15 @@ struct BlockFmhaPipelineQRKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
@@ -335,6 +343,7 @@ struct BlockFmhaPipelineQRKSVS
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -459,6 +468,7 @@ struct BlockFmhaPipelineQRKSVS
}
move_tile_window(v_dram_window, {0, kK1});
tile_elementwise_inout(p_compute_element_func, p_compute);
const auto p = cast_tile<PDataType>(p_compute);
// STAGE 3, KV gemm
@@ -545,6 +555,8 @@ struct BlockFmhaPipelineQRKSVS
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
return o_acc;
}
@@ -573,6 +585,9 @@ struct BlockFmhaPipelineQRKSVS
identity{},
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
scale,
smem_ptr);

View File

@@ -122,7 +122,10 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction>
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -134,6 +137,9 @@ struct BlockFmhaPipelineQRKSVSAsync
const BiasElementFunction& bias_element_func,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale,
void* smem_ptr) const
@@ -362,13 +368,15 @@ struct BlockFmhaPipelineQRKSVSAsync
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
@@ -378,6 +386,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -521,6 +530,7 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
tile_elementwise_inout(p_compute_element_func, p_compute);
const auto p = cast_tile<PDataType>(p_compute);
// STAGE 3, KV gemm
@@ -640,6 +650,8 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
return o_acc;
}
@@ -668,6 +680,9 @@ struct BlockFmhaPipelineQRKSVSAsync
identity{},
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
scale,
smem_ptr);

View File

@@ -97,7 +97,10 @@ struct BlockFmhaPipelineQSKSVS
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction>
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -109,6 +112,9 @@ struct BlockFmhaPipelineQSKSVS
const BiasElementFunction& bias_element_func,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale,
void* smem_ptr) const
@@ -310,13 +316,15 @@ struct BlockFmhaPipelineQSKSVS
// STAGE 2, scale, add bias, mask, softmax
if constexpr(kHasBias)
{
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = scale * x + type_convert<SaccDataType>(bias_element_func(y));
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x = scale * x + log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
@@ -326,6 +334,7 @@ struct BlockFmhaPipelineQSKSVS
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
tile_elementwise_inout(s_acc_element_func, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -450,6 +459,7 @@ struct BlockFmhaPipelineQSKSVS
}
move_tile_window(v_dram_window, {0, kK1});
tile_elementwise_inout(p_compute_element_func, p_compute);
const auto p = cast_tile<PDataType>(p_compute);
// STAGE 3, KV gemm
@@ -536,6 +546,8 @@ struct BlockFmhaPipelineQSKSVS
});
});
tile_elementwise_inout(o_acc_element_func, o_acc);
return o_acc;
}