hack bias for padding

This commit is contained in:
Ye Wang
2025-03-03 11:23:32 -06:00
parent 2b34a25675
commit 7e0bed7a82
2 changed files with 99 additions and 30 deletions

View File

@@ -304,10 +304,54 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VDataType> v_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
// Use a [batch, 1, seqlen_q, seqlen_k] bias tensor holding 0 / -inf as the padding mask.
// bias_enum::elementwise_bias now serves as the padding indicator.
ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
? std::array<ck_tile::index_t, 4>{batch, 1, shape_seqlen_q, shape_seqlen_k}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
// Padding in the backward pass requires both seqlens_q and seqlens_kv.
ck_tile::HostTensor<ck_tile::index_t> seqlen_q_host(std::array<ck_tile::index_t, 1>{batch});
ck_tile::HostTensor<ck_tile::index_t> seqlen_kv_host(std::array<ck_tile::index_t, 1>{batch});
if(bias.type==bias_enum::elementwise_bias){
// initialize the bias_host to all 0
bias_host.SetZero();
// generate random seqlens for q and kv in each batch
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_q-1.f, seed}(seqlen_q_host);
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host);
// fill in -inf to bias_host each batch
for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){
// debug printing
std::cout<<"seqlen_q["<<b_i<<"]: "<<seqlen_q_host(b_i)<<std::endl;
std::cout<<"seqlen_kv["<<b_i<<"]: "<<seqlen_kv_host(b_i)<<std::endl;
// The per-batch bias looks like this (rows: q positions, columns: kv positions):
// 0, 0, 0, ..., -inf, -inf, -inf
// 0, 0, 0, ..., -inf, -inf, -inf
// 0, 0, 0, ..., -inf, -inf, -inf
// ...
// -inf, -inf, -inf, ... -inf,
// -inf, -inf, -inf, ... -inf,
// fill in the right part
for(ck_tile::index_t s_kv_i = seqlen_kv_host(b_i); s_kv_i < max_seqlen_k; s_kv_i++){
for(ck_tile::index_t s_q_j = 0; s_q_j < max_seqlen_q; s_q_j++){
bias_host(b_i, 0, s_q_j, s_kv_i) = -std::numeric_limits<BiasDataType>::infinity();
}
}
// fill in the bottom part
for(ck_tile::index_t s_q_i = seqlen_q_host(b_i); s_q_i < max_seqlen_q; s_q_i++){
for(ck_tile::index_t s_kv_j = 0; s_kv_j < max_seqlen_k; s_kv_j++){
bias_host(b_i, 0, s_q_i, s_kv_j) = -std::numeric_limits<BiasDataType>::infinity();
}
}
}
}
ck_tile::HostTensor<AccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
@@ -342,7 +386,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
}
ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
}
else if(init_method == 1)
@@ -350,7 +396,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
}
else if(init_method == 2)
@@ -358,7 +406,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
}
if(bias.type == bias_enum::alibi)
@@ -453,9 +503,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
v3_bf16_cvt};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
/// NOTE: we broadcast bias from [batch, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence the 'nhead_stride_bias' is 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
@@ -484,7 +533,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_bias = 0;
const ck_tile::index_t batch_stride_bias = max_seqlen_q*max_seqlen_k;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
@@ -673,10 +722,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b, 0, i[1] + query_offset, i[2]); });
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,

View File

@@ -618,11 +618,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
// Use a [batch, 1, 1, seqlen_k] bias tensor holding 0 / -inf as the padding mask.
// bias_enum::elementwise_bias now serves as the padding indicator.
ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
? std::array<ck_tile::index_t, 4>{batch, 1, 1, shape_seqlen_k}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
// seqlens_kv is enough for validation as seqlens_q is implicit there
ck_tile::HostTensor<ck_tile::index_t> seqlen_kv_host(std::array<ck_tile::index_t, 1>{batch});
if(bias.type==bias_enum::elementwise_bias){
bias_host.SetZero();
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host);
for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){
std::cout<<"seqlen_kv["<<b_i<<"]: "<<seqlen_kv_host(b_i)<<std::endl;
for(ck_tile::index_t s_kv_i = seqlen_kv_host(b_i); s_kv_i < max_seqlen_k; s_kv_i++){
bias_host(b_i, 0, 0, s_kv_i) = -std::numeric_limits<BiasDataType>::infinity();
}
}
}
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
@@ -672,7 +687,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
}
else if(init_method == "ni")
{
@@ -681,7 +698,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
}
else if(init_method == "uf" || init_method == "1")
{
@@ -690,7 +709,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
}
else if(init_method == "nf")
{
@@ -699,7 +720,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
}
}
else if(init_method == "tf" || init_method == "2")
{
@@ -708,7 +731,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillTrigValue<KDataType>{}(knew_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
}
else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization
@@ -722,7 +747,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
// bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
if(bias.type!=bias_enum::elementwise_bias){
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
}
if(bias.type == bias_enum::alibi)
{
@@ -864,9 +891,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// NOTE: we broadcast bias from [batch, 1, 1, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
/// 'stride_bias' are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
@@ -884,7 +911,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_bias = 0;
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = (hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
@@ -908,8 +935,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_bias = 0;
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
@@ -925,7 +951,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = shape_seqlen_k;
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
@@ -1378,12 +1404,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, 1, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b_idx, 0, 0, i[2] + key_offset); });
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,