mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
remove unnecessary features
This commit is contained in:
@@ -597,8 +597,8 @@ struct FmhaFwdDecodeKernel
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
index_t kv_l2p_offset =
|
||||
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
|
||||
// index_t kv_l2p_offset =
|
||||
// 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -648,7 +648,7 @@ struct FmhaFwdDecodeKernel
|
||||
if(kargs.is_gappy)
|
||||
{
|
||||
// seqstart_k_ptr has a different meaning in this case
|
||||
kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
|
||||
// kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -809,66 +809,6 @@ struct FmhaFwdDecodeKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
|
||||
|
||||
return make_page_block_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
kargs.batch_stride_k, // kcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size,
|
||||
k_dram,
|
||||
make_k_dram(nullptr,
|
||||
(kv_l2p_offset + kargs.seqlen_k) -
|
||||
(num_blocks - 1) * kargs.page_block_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_page_block_navigator(k_dram);
|
||||
}
|
||||
}();
|
||||
|
||||
auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
|
||||
|
||||
return make_page_block_navigator<const VDataType, 1>(
|
||||
kargs.v_ptr,
|
||||
kargs.batch_stride_v, // vcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size,
|
||||
v_dram,
|
||||
make_v_dram(nullptr,
|
||||
(kv_l2p_offset + kargs.seqlen_k) -
|
||||
(num_blocks - 1) * kargs.page_block_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_page_block_navigator(v_dram);
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
@@ -880,10 +820,11 @@ struct FmhaFwdDecodeKernel
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
|
||||
auto v_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram, make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), {0, 0});
|
||||
|
||||
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
|
||||
/// the following copy capture of 'i_nhead' once C++20 is available
|
||||
@@ -1006,70 +947,24 @@ struct FmhaFwdDecodeKernel
|
||||
}
|
||||
}();
|
||||
|
||||
AttentionVariant variant;
|
||||
const auto variant_params = [&] {
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
|
||||
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
|
||||
}
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
|
||||
|
||||
auto o_acc_tile = [&, i_split_ = i_split]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_acc_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator, // Remove it
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator, // Remove it
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant, // Remove it
|
||||
variant_params, // Remove it
|
||||
block_indices, // Remove it
|
||||
kv_l2p_offset, // Remove it
|
||||
smem_ptr);
|
||||
}
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
// k_page_block_navigator, // Remove it
|
||||
v_dram_window,
|
||||
// v_page_block_navigator, // Remove it
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
// variant, // Remove it
|
||||
// variant_params, // Remove it
|
||||
// block_indices, // Remove it
|
||||
// kv_l2p_offset, // Remove it
|
||||
smem_ptr);
|
||||
}();
|
||||
|
||||
// Oacc DRAM and Oacc DRAM window
|
||||
|
||||
@@ -125,21 +125,15 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
@@ -147,29 +141,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = bias_dram_block_window_tmp;
|
||||
ignore = position_encoding;
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
|
||||
@@ -248,29 +239,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
async_load_tile(q_lds_store_window, q_dram_window);
|
||||
|
||||
// K tile in LDS
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
// divisible by kN0)
|
||||
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start =
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
@@ -297,11 +275,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
@@ -319,14 +294,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Policy::template MakeVRegTileDistribution<Problem>());
|
||||
|
||||
// Bias tile in LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
// const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
// auto bias_dram_window =
|
||||
// make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// bias_dram_block_window_tmp.get_window_lengths(),
|
||||
// {bias_origin.at(number<0>{}),
|
||||
// logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
// aligned_physical_seqlen_k_start)}, // M/N
|
||||
// Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
block_sync_lds_direct_load<0>();
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
@@ -352,17 +327,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// __builtin_amdgcn_sched_barrier(
|
||||
// 0); // prevent from messing up the order of global loads
|
||||
// }
|
||||
// const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// __builtin_amdgcn_sched_barrier(
|
||||
// 0); // prevent from messing up the order of global loads
|
||||
// }
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
@@ -379,105 +354,74 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// // STAGE 2, scale_s, add bias, mask, softmax
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
// tile_elementwise_inout(
|
||||
// [&](auto& x, const auto& y) {
|
||||
// x += log2e_v<SaccDataType> *
|
||||
// type_convert<SaccDataType>(bias_element_func(y));
|
||||
// },
|
||||
// s_acc,
|
||||
// bias_tile);
|
||||
// }
|
||||
// else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
// {
|
||||
// const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
// i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
// sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
// const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
// position_encoding accept only logical coordinates, do conversion here
|
||||
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
[&variant, &variant_params, &block_indices](auto& x) {
|
||||
x = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, x),
|
||||
block_indices.batch_idx,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
// s_acc(i_j_idx) *= scale_s;
|
||||
// // position_encoding accept only logical coordinates, do conversion here
|
||||
// position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
// move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
set_tile_if(
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
if(i_total_loops == (num_total_loop - 1))
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ ||
|
||||
physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
// const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
// i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// mask accepts only logical coordinates; do the conversion here
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
@@ -486,17 +430,13 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// move K tile windows after current status checked
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
@@ -550,12 +490,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -572,9 +509,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
@@ -591,8 +525,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -611,9 +544,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
@@ -644,7 +574,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -661,9 +590,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
|
||||
@@ -271,17 +271,22 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
|
||||
|
||||
Reference in New Issue
Block a user