Remove the unused k_element_func and v_element_func parameters from the forward pipelines

This commit is contained in:
Qianfeng Zhang
2025-11-13 08:43:05 +00:00
parent 881ddc5741
commit 95c1bb25e3
4 changed files with 23 additions and 66 deletions

View File

@@ -118,8 +118,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
@@ -128,10 +126,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
@@ -143,9 +139,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
static_assert(
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QKVDataType,
@@ -380,8 +373,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[i_k1],
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -464,8 +456,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -508,7 +499,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
};
@@ -546,9 +537,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
identity{},

View File

@@ -120,8 +120,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
@@ -130,10 +128,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
@@ -145,9 +141,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
static_assert(
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QKVDataType,
@@ -357,8 +350,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}], k_tiles[i_k1]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -453,7 +445,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<i_k1>{}]));
v_tiles[number<i_k1>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -498,9 +490,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
identity{},

View File

@@ -118,8 +118,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
@@ -128,10 +126,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
@@ -143,8 +139,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
ignore = scale_p;
static_assert(
@@ -396,9 +390,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
k_tiles[number<i_k1 % NumPrefetchK>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -505,8 +498,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -587,8 +579,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -616,7 +607,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffled_tile));
v_shuffled_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
};
@@ -668,9 +659,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
identity{},

View File

@@ -120,8 +120,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
@@ -130,10 +128,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
@@ -145,8 +141,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
ignore = scale_p;
static_assert(
@@ -373,9 +367,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
k_tiles[number<i_k1 % NumPrefetchK>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -477,8 +470,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<0>{}]));
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -564,8 +556,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_s_barrier();
};
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<1>{}]));
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -592,7 +583,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<i_k1 + 2>{}]));
v_tiles[number<i_k1 + 2>{}]);
__builtin_amdgcn_sched_barrier(0x00000001);
};
@@ -644,9 +635,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
identity{},