mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
enable prefill overload operator().
This commit is contained in:
@@ -582,9 +582,6 @@ struct FmhaFwdDecodeKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
@@ -948,23 +945,51 @@ struct FmhaFwdDecodeKernel
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&, i_split_ = i_split]() {
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
// k_page_block_navigator, // Remove it
|
||||
v_dram_window,
|
||||
// v_page_block_navigator, // Remove it
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
// variant, // Remove it
|
||||
// variant_params, // Remove it
|
||||
// block_indices, // Remove it
|
||||
// kv_l2p_offset, // Remove it
|
||||
smem_ptr);
|
||||
if constexpr(FmhaPipeline::kM0 == 128)
|
||||
{
|
||||
// allocate double lds
|
||||
// add __restrict__ here to avoid aliasing
|
||||
__shared__ char
|
||||
smem_ptrk0[FmhaPipeline::Policy::
|
||||
template GetSmemSizeK<typename FmhaPipeline::Problem, true>()];
|
||||
__shared__ char
|
||||
smem_ptrk1[FmhaPipeline::Policy::
|
||||
template GetSmemSizeK<typename FmhaPipeline::Problem, true>()];
|
||||
__shared__ char smem_ptrv0
|
||||
[FmhaPipeline::Policy::template GetSmemSizeV<typename FmhaPipeline::Problem>()];
|
||||
__shared__ char smem_ptrv1
|
||||
[FmhaPipeline::Policy::template GetSmemSizeV<typename FmhaPipeline::Problem>()];
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptrk0,
|
||||
smem_ptrk1,
|
||||
smem_ptrv0,
|
||||
smem_ptrv1);
|
||||
}
|
||||
else
|
||||
{
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
kargs.num_splits, // Remove it
|
||||
i_split_, // Remove it
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
// Oacc DRAM and Oacc DRAM window
|
||||
|
||||
@@ -14,6 +14,9 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
@@ -43,8 +46,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
|
||||
static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1);
|
||||
static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1);
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -124,6 +127,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
// Decode
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
@@ -149,15 +153,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1],
|
||||
"wrong!");
|
||||
ignore = bias_dram_block_window_tmp;
|
||||
ignore = position_encoding;
|
||||
@@ -192,7 +195,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
@@ -287,20 +290,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds,
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>());
|
||||
|
||||
// Bias tile in LDS
|
||||
// const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
// auto bias_dram_window =
|
||||
// make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// bias_dram_block_window_tmp.get_window_lengths(),
|
||||
// {bias_origin.at(number<0>{}),
|
||||
// logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
// aligned_physical_seqlen_k_start)}, // M/N
|
||||
// Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
block_sync_lds_direct_load<0>();
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
@@ -312,7 +305,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 == k1_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
@@ -326,40 +319,31 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
|
||||
// move V tile windows
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// __builtin_amdgcn_sched_barrier(
|
||||
// 0); // prevent from messing up the order of global loads
|
||||
// }
|
||||
// const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// __builtin_amdgcn_sched_barrier(
|
||||
// 0); // prevent from messing up the order of global loads
|
||||
// }
|
||||
if constexpr(1 < k0_loops)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
if constexpr(i_k0 == 0){
|
||||
if constexpr(i_k0 == 0)
|
||||
{
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
}
|
||||
else{
|
||||
else
|
||||
{
|
||||
block_sync_lds_direct_load<0>();
|
||||
}
|
||||
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
|
||||
// loop along the [K]ey head dimension
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
block_sync_lds();
|
||||
@@ -369,56 +353,23 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)});
|
||||
}
|
||||
|
||||
if constexpr(k0_loops==1){
|
||||
if constexpr(k0_loops == 1)
|
||||
{
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
}
|
||||
else{
|
||||
else
|
||||
{
|
||||
block_sync_lds_direct_load<0>();
|
||||
}
|
||||
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// // STAGE 2, scale_s, add bias, mask, softmax
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
// tile_elementwise_inout(
|
||||
// [&](auto& x, const auto& y) {
|
||||
// x += log2e_v<SaccDataType> *
|
||||
// type_convert<SaccDataType>(bias_element_func(y));
|
||||
// },
|
||||
// s_acc,
|
||||
// bias_tile);
|
||||
// }
|
||||
// else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
// {
|
||||
// const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
// i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
// sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
// const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
// const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
// s_acc(i_j_idx) *= scale_s;
|
||||
// // position_encoding accept only logical coordinates, do conversion here
|
||||
// position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
// move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
if(i_total_loops == (num_total_loop - 1))
|
||||
@@ -429,8 +380,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ ||
|
||||
@@ -447,19 +397,15 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
// const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
// i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
const auto row = q_origin.at(I0) + tile_idx.at(I0);
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
@@ -519,10 +465,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
@@ -554,7 +500,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
@@ -576,24 +522,39 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}();
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this uses a different equation from the FA v2 paper,
|
||||
// but produces the correct result.
|
||||
// Is the equation wrong?
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts>();
|
||||
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
p_tile, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_tile, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kN1, k1_loops * kK1>{}));
|
||||
if constexpr(1 < k1_loops)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
// loop along the [V]alue sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
}
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p_tile,
|
||||
sequence<0, (k1_loops - 1) * kK1>{},
|
||||
sequence<kM0, k1_loops * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
@@ -603,7 +564,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
@@ -632,7 +593,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
@@ -643,7 +604,519 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Prefill, double lds
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
|
||||
index_t num_splits,
|
||||
index_t i_split,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* __restrict__ smem_ptrk0,
|
||||
void* __restrict__ smem_ptrk1,
|
||||
void* __restrict__ smem_ptrv0,
|
||||
void* __restrict__ smem_ptrv1) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1],
|
||||
"wrong!");
|
||||
ignore = bias_dram_block_window_tmp;
|
||||
ignore = position_encoding;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
// init M, L
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(I0), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t logical_num_total_loop =
|
||||
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
|
||||
if(logical_num_total_loop <= 0)
|
||||
{
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: o_acc has already been cleared above, so return it as-is
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
// Q tile in LDS
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem, true>());
|
||||
|
||||
// auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
// static_cast<QDataType*>(smem_ptrk0),
|
||||
// Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
// auto q_lds_store_window = make_tile_window(
|
||||
// q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// auto q_lds_read_window =
|
||||
// make_tile_window(q_lds,
|
||||
// Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
// {0, 0},
|
||||
// Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
// async_load_tile(q_lds_store_window, q_dram_window);
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
// K tile in LDS
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
// divisible by kN0)
|
||||
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem, true>());
|
||||
|
||||
auto k_lds = make_tuple(make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk1),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>()));
|
||||
|
||||
auto k_lds_write_windows =
|
||||
make_tuple(make_tile_window(
|
||||
k_lds.at(I0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
|
||||
{0, 0}),
|
||||
make_tile_window(
|
||||
k_lds.at(I1),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
|
||||
{0, 0}));
|
||||
|
||||
auto k_lds_read_windows =
|
||||
make_tuple(make_tile_window(k_lds.at(I0),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>()),
|
||||
make_tile_window(k_lds.at(I1),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>()));
|
||||
|
||||
// S tile in LDS
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptrk0) +
|
||||
Policy::template GetSmemSizeK<Problem>()),
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>());
|
||||
auto s_write_lds_window = make_tile_window(
|
||||
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto s_read_lds_window =
|
||||
make_tile_window(s_lds,
|
||||
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds = make_tuple(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>()));
|
||||
|
||||
auto v_lds_write_windows = make_tuple(
|
||||
make_tile_window(v_lds.at(I0),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0}),
|
||||
make_tile_window(v_lds.at(I1),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0}));
|
||||
|
||||
auto v_lds_read_windows =
|
||||
make_tuple(make_tile_window(v_lds.at(I0),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>()),
|
||||
make_tile_window(v_lds.at(I1),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>()));
|
||||
|
||||
// block_sync_lds_direct_load<0>();
|
||||
// auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
|
||||
// block_sync_lds();
|
||||
async_load_tile(k_lds_write_windows.at(I0), k_dram_window);
|
||||
async_load_tile(v_lds_write_windows.at(I0), v_dram_window);
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
auto mainloop = [&](auto lds_write_buf, auto lds_read_buf) {
|
||||
auto k_lds_write_window = k_lds_write_windows.at(lds_write_buf);
|
||||
auto k_lds_read_window = k_lds_read_windows.at(lds_read_buf);
|
||||
auto v_lds_write_window = v_lds_write_windows.at(lds_write_buf);
|
||||
auto v_lds_read_window = v_lds_read_windows.at(lds_read_buf);
|
||||
|
||||
block_sync_lds();
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
if constexpr(1 < k0_loops)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// loop along the [K]ey head dimension
|
||||
move_tile_window(k_lds_read_window, {0, kK0});
|
||||
k_tile = load_tile(k_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
|
||||
}
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
if(i_total_loops == (num_total_loop - 1))
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return col < physical_seqlen_k_start_ ||
|
||||
physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_end_ <= col;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(I0) + tile_idx.at(I0);
|
||||
const auto col = k_origin.at(I0) + tile_idx.at(I1);
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Gemm1
|
||||
auto s_new = [&]() {
|
||||
if constexpr(kNWarp > 1)
|
||||
{
|
||||
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
return load_tile(s_read_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
}
|
||||
}();
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s_new,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s_new.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
auto p_tile = make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
// An unexpected vmcnt(0) gets inserted here, probably due to an aliasing issue.
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
if constexpr(1 < k1_loops)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
// loop along the [V]alue sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
}
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p_tile,
|
||||
sequence<0, (k1_loops - 1) * kK1>{},
|
||||
sequence<kM0, k1_loops * kK1>{}),
|
||||
v_tile);
|
||||
};
|
||||
|
||||
do
|
||||
{
|
||||
mainloop(I1, I0);
|
||||
i_total_loops++;
|
||||
if(i_total_loops == (num_total_loop))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
mainloop(I0, I1);
|
||||
i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// store lse acc
|
||||
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
|
||||
sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
|
||||
@@ -68,7 +68,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
@@ -77,43 +77,77 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool BypassLDS = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
if constexpr(!BypassLDS)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
return q_block_dstr;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool LoadOnce = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock =
|
||||
LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -190,11 +224,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool LoadOnce = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock =
|
||||
LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
@@ -211,7 +246,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
@@ -335,7 +370,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
@@ -370,7 +405,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -502,10 +537,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
sizeof(typename Problem::QDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool LoadOnce = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
|
||||
{
|
||||
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
return MakeKLdsBlockDescriptor<Problem, LoadOnce>().get_element_space_size() *
|
||||
sizeof(typename Problem::KDataType);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user