mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Remove the k_element_func and v_element_func from the pipeline since they are not used
This commit is contained in:
@@ -118,8 +118,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
@@ -128,10 +126,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -143,9 +139,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType,
|
||||
@@ -380,8 +373,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[i_k1],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -464,8 +456,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -508,7 +499,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 1>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
v_shuffled_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
@@ -546,9 +537,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
@@ -120,8 +120,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
@@ -130,10 +128,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -145,9 +141,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType,
|
||||
@@ -357,8 +350,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}], k_tiles[i_k1]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -453,7 +445,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[number<i_k1>{}]));
|
||||
v_tiles[number<i_k1>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -498,9 +490,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
@@ -118,8 +118,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
@@ -128,10 +126,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -143,8 +139,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
ignore = scale_p;
|
||||
|
||||
static_assert(
|
||||
@@ -396,9 +390,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -505,8 +498,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -587,8 +579,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]);
|
||||
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -616,7 +607,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
shuffle_tile(v_shuffled_tile, v_tiles[number<i_k1 + 2>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffled_tile));
|
||||
v_shuffled_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
@@ -668,9 +659,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
@@ -120,8 +120,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
@@ -130,10 +128,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -145,8 +141,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
ignore = scale_p;
|
||||
|
||||
static_assert(
|
||||
@@ -373,9 +367,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
// STAGE 1, Gemm_0 ( S = Q@K )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
store_tile(
|
||||
k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_k1 % NumPrefetchK>{}]));
|
||||
store_tile(k_lds_write_windows[number<i_k1 % NumKVLdsBuffers>{}],
|
||||
k_tiles[number<i_k1 % NumPrefetchK>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -477,8 +470,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[number<0>{}]));
|
||||
store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -564,8 +556,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[number<1>{}]));
|
||||
store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -592,7 +583,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[number<i_k1 + 2>{}]));
|
||||
v_tiles[number<i_k1 + 2>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
};
|
||||
@@ -644,9 +635,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user