mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
This reverts commit dd0b4294af.
This commit is contained in:
@@ -12,7 +12,6 @@ enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
BLOCKSCALE,
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
@@ -28,10 +27,5 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR
|
||||
{
|
||||
static constexpr const char* name = "pertensor";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCALE>
|
||||
{
|
||||
static constexpr const char* name = "blockscale";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -168,29 +168,6 @@ struct FmhaFwdKernel
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs
|
||||
{
|
||||
const int32_t* block_scale_seqstart_q_ptr;
|
||||
const int32_t* block_scale_seqstart_k_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -266,12 +243,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -295,12 +269,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -357,9 +328,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -367,9 +335,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -378,8 +343,6 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -450,23 +413,6 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -532,9 +478,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -542,9 +485,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -552,8 +492,6 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -590,9 +528,6 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -600,9 +535,6 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -610,8 +542,6 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -651,9 +581,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
@@ -661,9 +588,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_q_descale,
|
||||
ck_tile::index_t batch_stride_k_descale,
|
||||
ck_tile::index_t batch_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -671,8 +595,6 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -709,9 +631,6 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
@@ -719,9 +638,6 @@ struct FmhaFwdKernel
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
batch_stride_q_descale,
|
||||
batch_stride_k_descale,
|
||||
batch_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -729,8 +645,6 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -752,8 +666,6 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -773,9 +685,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -785,8 +694,6 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -856,24 +763,6 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
|
||||
kargs.block_scale_seqstart_q_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
|
||||
kargs.block_scale_seqstart_k_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -925,8 +814,6 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -946,9 +833,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -957,8 +841,6 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -978,8 +860,6 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -999,9 +879,6 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -1010,8 +887,6 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -1034,8 +909,6 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -1055,9 +928,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_q_descale,
|
||||
ck_tile::index_t nhead_stride_k_descale,
|
||||
ck_tile::index_t nhead_stride_v_descale,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
@@ -1066,8 +936,6 @@ struct FmhaFwdKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
@@ -1087,8 +955,6 @@ struct FmhaFwdKernel
|
||||
seqstart_k_ptr,
|
||||
seqlen_q_ptr,
|
||||
seqlen_k_ptr,
|
||||
block_scale_seqstart_q_ptr,
|
||||
block_scale_seqstart_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -1108,9 +974,6 @@ struct FmhaFwdKernel
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
nhead_stride_q_descale,
|
||||
nhead_stride_k_descale,
|
||||
nhead_stride_v_descale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
@@ -1119,8 +982,6 @@ struct FmhaFwdKernel
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
block_scale_size_q,
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
@@ -1250,16 +1111,13 @@ struct FmhaFwdKernel
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_q_descale = 0;
|
||||
long_index_t batch_offset_k_descale = 0;
|
||||
long_index_t batch_offset_v_descale = 0;
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
@@ -1295,14 +1153,6 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
|
||||
const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
|
||||
batch_offset_q_descale = bquery_start;
|
||||
batch_offset_k_descale = bkey_start;
|
||||
batch_offset_v_descale = bkey_start;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1370,15 +1220,6 @@ struct FmhaFwdKernel
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
batch_offset_q_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
|
||||
batch_offset_k_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
|
||||
batch_offset_v_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
|
||||
// If cumulative seqlen pointers are provided, override per-batch effective lengths
|
||||
@@ -1699,8 +1540,7 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
@@ -1741,62 +1581,8 @@ struct FmhaFwdKernel
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const float* q_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
|
||||
batch_offset_q_descale;
|
||||
const float* k_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k_descale +
|
||||
batch_offset_k_descale;
|
||||
const float* v_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v_descale +
|
||||
batch_offset_v_descale;
|
||||
|
||||
size_t idx = i_m0 / kargs.block_scale_size_q;
|
||||
float q_descale = q_descale_ptr[idx];
|
||||
// BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
|
||||
// Both P and rowsum are scaled by 2^shift, canceling in normalization
|
||||
// No additional scaling needed in p_compute_element_func or o_acc_element_func
|
||||
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
scales<float>(q_descale), // s_acc_element_func
|
||||
identity{}, // p_compute_element_func - No scaling (done in exp2)
|
||||
identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.block_scale_size_kv,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
@@ -57,13 +57,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
@@ -172,9 +167,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -366,13 +358,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
@@ -442,20 +427,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
}
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -473,7 +449,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
@@ -490,7 +466,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
@@ -595,21 +571,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -617,13 +579,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -714,39 +676,18 @@ struct BlockFmhaPipelineQRKSVS
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
v_descale = v_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return o_acc0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return o_acc;
|
||||
}
|
||||
}();
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc_,
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
@@ -781,16 +722,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc_,
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -910,9 +846,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -46,7 +46,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -65,10 +64,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
@@ -195,9 +190,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -411,13 +403,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
@@ -464,20 +449,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -495,7 +471,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
@@ -512,7 +488,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
@@ -654,21 +630,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -676,13 +638,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m);
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -773,27 +735,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#endif
|
||||
}();
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
// K and V share the same seqlen_k position within a block
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
v_descale = v_descale_ptr[kv_idx];
|
||||
}
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return o_acc0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return o_acc;
|
||||
}
|
||||
}();
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
@@ -803,7 +745,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc_,
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
@@ -866,19 +808,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc_,
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
|
||||
}
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -986,9 +922,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user