mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
hack bias for padding
This commit is contained in:
@@ -304,10 +304,54 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
|
||||
// use Bx1xSxS with 0/-inf for padding mask
|
||||
// bias_enum::elementwise_bias now serve as padding indicator
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
||||
? std::array<ck_tile::index_t, 4>{batch, 1, shape_seqlen_q, shape_seqlen_k}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
// require both seqlens_q and seqlens_kv for padding in bwd
|
||||
ck_tile::HostTensor<ck_tile::index_t> seqlen_q_host(std::array<ck_tile::index_t, 1>{batch});
|
||||
ck_tile::HostTensor<ck_tile::index_t> seqlen_kv_host(std::array<ck_tile::index_t, 1>{batch});
|
||||
|
||||
if(bias.type==bias_enum::elementwise_bias){
|
||||
// initialize the bias_host to all 0
|
||||
bias_host.SetZero();
|
||||
|
||||
// generate random seqlens for q and kv in each batch
|
||||
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_q-1.f, seed}(seqlen_q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host);
|
||||
|
||||
// fill in -inf to bias_host each batch
|
||||
for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){
|
||||
// debug printing
|
||||
std::cout<<"seqlen_q["<<b_i<<"]: "<<seqlen_q_host(b_i)<<std::endl;
|
||||
std::cout<<"seqlen_kv["<<b_i<<"]: "<<seqlen_kv_host(b_i)<<std::endl;
|
||||
|
||||
// bias are like
|
||||
// 0, 0, 0, ..., -inf, -inf, -inf
|
||||
// 0, 0, 0, ..., -inf, -inf, -inf
|
||||
// 0, 0, 0, ..., -inf, -inf, -inf
|
||||
// ...
|
||||
// -inf, -inf, -inf, ... -inf,
|
||||
// -inf, -inf, -inf, ... -inf,
|
||||
|
||||
// fill in the right part
|
||||
for(ck_tile::index_t s_kv_i = seqlen_kv_host(b_i); s_kv_i < max_seqlen_k; s_kv_i++){
|
||||
for(ck_tile::index_t s_q_j = 0; s_q_j < max_seqlen_q; s_q_j++){
|
||||
bias_host(b_i, 0, s_q_j, s_kv_i) = -std::numeric_limits<BiasDataType>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
// fill in the bottom part
|
||||
for(ck_tile::index_t s_q_i = seqlen_q_host(b_i); s_q_i < max_seqlen_q; s_q_i++){
|
||||
for(ck_tile::index_t s_kv_j = 0; s_kv_j < max_seqlen_k; s_kv_j++){
|
||||
bias_host(b_i, 0, s_q_i, s_kv_j) = -std::numeric_limits<BiasDataType>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ck_tile::HostTensor<AccDataType> alibi_slope_host(
|
||||
bias.type == bias_enum::alibi
|
||||
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
|
||||
@@ -342,7 +386,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
}
|
||||
ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
@@ -350,7 +396,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
@@ -358,7 +406,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
@@ -453,9 +503,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
v3_bf16_cvt};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
/// NOTE: we broadcast bias from [batch, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence the 'nhead_stride_bias' is 0.
|
||||
// setup stride_* arguments
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
@@ -484,7 +533,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_bias = 0;
|
||||
const ck_tile::index_t batch_stride_bias = max_seqlen_q*max_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
|
||||
@@ -673,10 +722,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b, 0, i[1] + query_offset, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
|
||||
@@ -618,11 +618,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
||||
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
// use Bx1x1xS bias with 0/-inf for padding mask
|
||||
// bias_enum::elementwise_bias now serve as padding indicator
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
? std::array<ck_tile::index_t, 4>{batch, 1, 1, shape_seqlen_k}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
// seqlens_kv is enough for validation as seqlens_q is implicit there
|
||||
ck_tile::HostTensor<ck_tile::index_t> seqlen_kv_host(std::array<ck_tile::index_t, 1>{batch});
|
||||
if(bias.type==bias_enum::elementwise_bias){
|
||||
bias_host.SetZero();
|
||||
ck_tile::FillUniformDistributionIntegerValue<ck_tile::index_t>{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host);
|
||||
for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){
|
||||
std::cout<<"seqlen_kv["<<b_i<<"]: "<<seqlen_kv_host(b_i)<<std::endl;
|
||||
for(ck_tile::index_t s_kv_i = seqlen_kv_host(b_i); s_kv_i < max_seqlen_k; s_kv_i++){
|
||||
bias_host(b_i, 0, 0, s_kv_i) = -std::numeric_limits<BiasDataType>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
|
||||
bias.type == bias_enum::alibi
|
||||
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
|
||||
@@ -672,7 +687,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
}
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
@@ -681,7 +698,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
}
|
||||
else if(init_method == "uf" || init_method == "1")
|
||||
{
|
||||
@@ -690,7 +709,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
@@ -699,7 +720,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
}
|
||||
else if(init_method == "tf" || init_method == "2")
|
||||
{
|
||||
@@ -708,7 +731,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillTrigValue<KDataType>{}(knew_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
}
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
init_method == "3") // suitable for fp8 quantization
|
||||
@@ -722,7 +747,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// bias_fp8 = qscale_bias * bias_fp32
|
||||
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
|
||||
// Assume bias is in [-1.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
if(bias.type!=bias_enum::elementwise_bias){
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
}
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
@@ -864,9 +891,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// NOTE: we broadcast bias from [batch, 1, 1, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
/// 'stride_bias' are 0.
|
||||
// setup stride_* arguments
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
@@ -884,7 +911,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_bias = 0;
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = (hdim_v);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
@@ -908,8 +935,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_bias = 0;
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
@@ -925,7 +951,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
@@ -1378,12 +1404,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, 1, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b_idx, 0, 0, i[2] + key_offset); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
|
||||
Reference in New Issue
Block a user