Remove un-used element-wise functions passed through pipelines' operator() interfaces

This commit is contained in:
Qianfeng Zhang
2026-06-12 08:40:18 +00:00
parent cc7e216fa6
commit d07b37c097
6 changed files with 29 additions and 301 deletions

View File

@@ -115,22 +115,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -181,7 +171,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
};
@@ -302,8 +291,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
clear_tile(o_acc);
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
using v_tile_type = decltype(load_tile(v_dram_window));
@@ -329,8 +316,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
@@ -347,8 +332,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
const auto bias_tile = load_tile(bias_dram_window);
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
[&scale_s](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(y);
},
pcomp_tile,
bias_tile);
@@ -420,7 +405,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p = cast_tile<PDataType>(pcomp_tile);
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
@@ -458,46 +443,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
};
} while(seqlen_k_curr < seqlen_k_end);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,
smem_ptr,
dropout);
}
};
} // namespace ck_tile

View File

@@ -55,10 +55,9 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
return Policy::template GetSmemSize<Problem>();
}
template <typename OAccDramBlockWindowTmp, typename OAccElementFunction>
template <typename OAccDramBlockWindowTmp>
CK_TILE_DEVICE auto
operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // M0*kOHeaddim tile
const OAccElementFunction& o_acc_element_func,
ck_tile::index_t o_acc_split_stride,
ck_tile::index_t num_splits) const
{
@@ -88,19 +87,8 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile);
};
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename OAccDramBlockWindowTmp>
CK_TILE_DEVICE auto
operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
ck_tile::index_t o_acc_split_stride,
ck_tile::index_t num_splits) const
{
return operator()(o_acc_dram_block_window_tmp, identity{}, o_acc_split_stride, num_splits);
};
};
} // namespace ck_tile

View File

@@ -116,22 +116,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -182,7 +172,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
};
@@ -309,8 +298,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
clear_tile(o_acc);
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
using v_tile_type = decltype(load_tile(v_dram_window));
@@ -338,8 +325,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
@@ -356,8 +341,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
const auto bias_tile = load_tile(bias_dram_window);
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
[&scale_s](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(y);
},
pcomp_tile,
bias_tile);
@@ -408,7 +393,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p = cast_tile<PDataType>(pcomp_tile);
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
@@ -442,46 +427,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
});
} while(seqlen_k_curr < seqlen_k_end);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,
smem_ptr,
dropout);
}
};
} // namespace ck_tile

View File

@@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -202,7 +190,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
@@ -211,8 +198,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
}
return o_acc;
@@ -338,8 +324,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
constexpr index_t NumPrefetchV = 2;
@@ -381,8 +365,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
@@ -397,8 +379,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const auto bias_tile = load_tile(bias_dram_window);
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
[&scale_s](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(y);
},
pcomp_tile,
bias_tile);
@@ -549,7 +531,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p = cast_tile<PDataType>(pcomp_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -612,8 +594,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
});
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
@@ -632,50 +613,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
lse_or_lse_acc_dram_block_window,
identity{},
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,
smem_ptr,
dropout);
}
};
} // namespace ck_tile

View File

@@ -63,17 +63,11 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
template <typename LSEaccDramBlockWindowTmp,
typename OAccDramBlockWindowTmp,
typename LSEDramBlockWindow,
typename OAccElementFunction,
typename LSEaccElementFunction,
typename LSEElementFunction>
typename LSEDramBlockWindow>
CK_TILE_DEVICE auto
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
LSEDramBlockWindow& lse_dram_block_window, // kM tile
const OAccElementFunction& o_acc_element_func,
const LSEaccElementFunction& lse_acc_element_func,
const LSEElementFunction& lse_element_func,
index_t o_acc_split_stride,
index_t num_splits,
void* smem_ptr) const
@@ -110,8 +104,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
auto lse_acc = load_tile(lse_acc_dram_window);
lse_acc = tile_elementwise_in(lse_acc_element_func, lse_acc);
using lse_acc_type = decltype(lse_acc);
constexpr auto lse_spans = lse_acc_type::get_distributed_spans();
@@ -173,7 +165,7 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
// in case kStoreLSE is false, LSEDramBlockWindow is null
if constexpr(!is_null_tile_window_v<LSEDramBlockWindow>)
{
store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum));
store_tile(lse_dram_block_window, lse_logsum);
}
// calculate scale value (used for adjusting the o_acc) for all splits for all rows in
@@ -240,32 +232,8 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
});
};
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename LSEaccDramBlockWindow,
typename OAccDramBlockWindowTmp,
typename LSEDramBlockWindow>
CK_TILE_DEVICE auto
operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp,
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
LSEDramBlockWindow& lse_dram_block_window, // kM tile
ck_tile::index_t o_acc_split_stride,
index_t num_splits,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window_tmp,
o_acc_dram_block_window_tmp,
lse_dram_block_window,
identity{},
identity{},
identity{},
o_acc_split_stride,
num_splits,
smem_ptr);
};
};
} // namespace ck_tile

View File

@@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -203,7 +191,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
if(seqlen_k_end <= seqlen_k_start)
{
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
@@ -212,8 +199,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
}
return o_acc;
@@ -345,8 +331,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
constexpr index_t NumPrefetchV = 2;
@@ -388,8 +372,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
@@ -406,8 +388,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const auto bias_tile = load_tile(bias_dram_window);
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
[&scale_s](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(y);
},
pcomp_tile,
bias_tile);
@@ -556,7 +538,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p = cast_tile<PDataType>(pcomp_tile);
__builtin_amdgcn_sched_barrier(0x00000001);
@@ -618,8 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
});
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
@@ -638,50 +619,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLU result
void* smem_ptr,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
lse_or_lse_acc_dram_block_window,
identity{},
identity{},
identity{},
identity{},
seqlen_k_start,
seqlen_k_end,
mask,
scale_s,
scale_p,
smem_ptr,
dropout);
}
};
} // namespace ck_tile