mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Remove unused element-wise functions passed through the pipelines' operator() interfaces
This commit is contained in:
@@ -115,22 +115,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -181,7 +171,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
if(seqlen_k_end <= seqlen_k_start)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
};
|
||||
@@ -302,8 +291,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
@@ -329,8 +316,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
@@ -347,8 +332,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
[&scale_s](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(y);
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
@@ -420,7 +405,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
@@ -458,46 +443,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
};
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -55,10 +55,9 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename OAccDramBlockWindowTmp, typename OAccElementFunction>
|
||||
template <typename OAccDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // M0*kOHeaddim tile
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
ck_tile::index_t o_acc_split_stride,
|
||||
ck_tile::index_t num_splits) const
|
||||
{
|
||||
@@ -88,19 +87,8 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline
|
||||
tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile);
|
||||
};
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename OAccDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
|
||||
ck_tile::index_t o_acc_split_stride,
|
||||
ck_tile::index_t num_splits) const
|
||||
{
|
||||
return operator()(o_acc_dram_block_window_tmp, identity{}, o_acc_split_stride, num_splits);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -116,22 +116,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -182,7 +172,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
if(seqlen_k_end <= seqlen_k_start)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
};
|
||||
@@ -309,8 +298,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
@@ -338,8 +325,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
@@ -356,8 +341,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
[&scale_s](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(y);
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
@@ -408,7 +393,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
// check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer,
|
||||
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
|
||||
@@ -442,46 +427,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
});
|
||||
} while(seqlen_k_curr < seqlen_k_end);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -202,7 +190,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
if(seqlen_k_end <= seqlen_k_start)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
@@ -211,8 +198,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
|
||||
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
|
||||
}
|
||||
|
||||
return o_acc;
|
||||
@@ -338,8 +324,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
constexpr index_t NumPrefetchV = 2;
|
||||
@@ -381,8 +365,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
@@ -397,8 +379,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
[&scale_s](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(y);
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
@@ -549,7 +531,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -612,8 +594,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
|
||||
}
|
||||
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -632,50 +613,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_or_lse_acc_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -63,17 +63,11 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
|
||||
template <typename LSEaccDramBlockWindowTmp,
|
||||
typename OAccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindow,
|
||||
typename OAccElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
typename LSEElementFunction>
|
||||
typename LSEDramBlockWindow>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile
|
||||
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
|
||||
LSEDramBlockWindow& lse_dram_block_window, // kM tile
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
const LSEaccElementFunction& lse_acc_element_func,
|
||||
const LSEElementFunction& lse_element_func,
|
||||
index_t o_acc_split_stride,
|
||||
index_t num_splits,
|
||||
void* smem_ptr) const
|
||||
@@ -110,8 +104,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
|
||||
auto lse_acc = load_tile(lse_acc_dram_window);
|
||||
|
||||
lse_acc = tile_elementwise_in(lse_acc_element_func, lse_acc);
|
||||
|
||||
using lse_acc_type = decltype(lse_acc);
|
||||
constexpr auto lse_spans = lse_acc_type::get_distributed_spans();
|
||||
|
||||
@@ -173,7 +165,7 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
// in case kStoreLSE is false, LSEDramBlockWindow is null
|
||||
if constexpr(!is_null_tile_window_v<LSEDramBlockWindow>)
|
||||
{
|
||||
store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
store_tile(lse_dram_block_window, lse_logsum);
|
||||
}
|
||||
|
||||
// calculate scale value (used for adjusting the o_acc) for all splits for all rows in
|
||||
@@ -240,32 +232,8 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
});
|
||||
};
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename LSEaccDramBlockWindow,
|
||||
typename OAccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindow>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp,
|
||||
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
|
||||
LSEDramBlockWindow& lse_dram_block_window, // kM tile
|
||||
ck_tile::index_t o_acc_split_stride,
|
||||
index_t num_splits,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window_tmp,
|
||||
o_acc_dram_block_window_tmp,
|
||||
lse_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
o_acc_split_stride,
|
||||
num_splits,
|
||||
smem_ptr);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename QElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEaccElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
const LSEaccElementFunction& lse_or_lse_acc_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -203,7 +191,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
if(seqlen_k_end <= seqlen_k_start)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
|
||||
{
|
||||
@@ -212,8 +199,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
|
||||
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
|
||||
}
|
||||
|
||||
return o_acc;
|
||||
@@ -345,8 +331,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
constexpr index_t NumPrefetchV = 2;
|
||||
@@ -388,8 +372,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
@@ -406,8 +388,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
|
||||
[&scale_s](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<CompDataType>(y);
|
||||
},
|
||||
pcomp_tile,
|
||||
bias_tile);
|
||||
@@ -556,7 +538,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
}
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -618,8 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
});
|
||||
|
||||
store_tile(lse_or_lse_acc_dram_block_window,
|
||||
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
|
||||
store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc);
|
||||
}
|
||||
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
@@ -638,50 +619,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_or_lse_acc_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user