mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Change output accum tensor layout of fmha fwd split-kv & combine kernels (#1527)
* Use same layout for o_acc and o tensor * Use better param names in partitioner * Remove redundant kargs 'max_seqlen_q' * Use better param names in splitkv kernel * Add comment for additional kernel arguments * Sync empty loop early return logics between pipelines * Pass more arguments to cmake in scripts * Align backslashes * Fix wrong o_acc tensor view strides * Change o_acc layout if o_perm=0 * Handle whole row masked via attn_bias * Use vector width = 1 for o_acc * Use more even split sizes
This commit is contained in:
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
|
||||
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split;
|
||||
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
|
||||
const index_t split_end = split_start + x_per_split;
|
||||
|
||||
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
|
||||
ck_tile::min(origin_end, split_end));
|
||||
|
||||
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_splits;
|
||||
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o,
|
||||
batch_stride_lse_acc};
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
const index_t padded_max_seqlen_q =
|
||||
const index_t padded_seqlen_q =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
|
||||
const index_t padded_hdim_v =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
|
||||
|
||||
return transform_tensor_view(
|
||||
o_acc_dram_view,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
|
||||
make_pass_through_transform(padded_hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
|
||||
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
batch_stride_v,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
}
|
||||
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.stride_o_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.hdim_v, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
|
||||
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead * num_splits,
|
||||
batch_size);
|
||||
|
||||
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
// lse_acc tile in LDS
|
||||
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t original_num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user