From 7e0bed7a82eb37413490db35ca7e7df6dc46ccbd Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 3 Mar 2025 11:23:32 -0600 Subject: [PATCH] hack bias for padding --- example/ck_tile/01_fmha/fmha_bwd.cpp | 70 +++++++++++++++++++++++----- example/ck_tile/01_fmha/fmha_fwd.cpp | 59 ++++++++++++++++------- 2 files changed, 99 insertions(+), 30 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 7f70befbb5..270ba9d803 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -304,10 +304,54 @@ bool run(const ck_tile::ArgParser& arg_parser) get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); ck_tile::HostTensor v_host( get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + // use a Bx1xSxS bias with 0/-inf entries as the padding mask + // bias_enum::elementwise_bias now serves as the padding indicator ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + ? 
std::array{batch, 1, shape_seqlen_q, shape_seqlen_k} : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + // require both seqlens_q and seqlens_kv for padding in bwd + ck_tile::HostTensor seqlen_q_host(std::array{batch}); + ck_tile::HostTensor seqlen_kv_host(std::array{batch}); + + if(bias.type==bias_enum::elementwise_bias){ + // initialize the bias_host to all 0 + bias_host.SetZero(); + + // generate random seqlens for q and kv in each batch + ck_tile::FillUniformDistributionIntegerValue{1.f, max_seqlen_q-1.f, seed}(seqlen_q_host); + ck_tile::FillUniformDistributionIntegerValue{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host); + + // fill in -inf to bias_host each batch + for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){ + // debug printing + std::cout<<"seqlen_q["<::infinity(); + } + } + + // fill in the bottom part + for(ck_tile::index_t s_q_i = seqlen_q_host(b_i); s_q_i < max_seqlen_q; s_q_i++){ + for(ck_tile::index_t s_kv_j = 0; s_kv_j < max_seqlen_k; s_kv_j++){ + bias_host(b_i, 0, s_q_i, s_kv_j) = -std::numeric_limits::infinity(); + } + } + } + } ck_tile::HostTensor alibi_slope_host( bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
std::array{1, nhead} @@ -342,7 +386,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + } ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); } else if(init_method == 1) @@ -350,7 +396,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); } else if(init_method == 2) @@ -358,7 +406,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); ck_tile::FillTrigValue{}(v_host); - ck_tile::FillTrigValue{}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillTrigValue{}(bias_host); + } ck_tile::FillTrigValue{}(do_host); } if(bias.type == bias_enum::alibi) @@ -453,9 +503,8 @@ bool run(const ck_tile::ArgParser& arg_parser) v3_bf16_cvt}; auto fmha_args = [&]() { assert(nhead % nhead_k == 0); - /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & - /// 'nhead_stride_bias' are 0. + /// NOTE: we broadcast bias from [batch, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence the 'nhead_stride_bias' is 0. 
// setup stride_* arguments const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); @@ -484,7 +533,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); - const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_bias = max_seqlen_q*max_seqlen_k; const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); @@ -673,10 +722,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b, 0, i[1] + query_offset, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..ce32cf6e87 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -618,11 +618,26 @@ bool run(const ck_tile::ArgParser& arg_parser) ? (is_v_rowmajor ? 
get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // use a Bx1x1xS bias with 0/-inf entries as the padding mask + // bias_enum::elementwise_bias now serves as the padding indicator ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + ? std::array{batch, 1, 1, shape_seqlen_k} : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // seqlens_kv is enough for validation as seqlens_q is implicit there + ck_tile::HostTensor seqlen_kv_host(std::array{batch}); + if(bias.type==bias_enum::elementwise_bias){ + bias_host.SetZero(); + ck_tile::FillUniformDistributionIntegerValue{1.f, max_seqlen_k-1.f, seed}(seqlen_kv_host); + for(ck_tile::index_t b_i = 0; b_i < batch; b_i++){ + std::cout<<"seqlen_kv["<::infinity(); + } + } + } + ck_tile::HostTensor alibi_slope_host( bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
std::array{1, nhead} @@ -672,7 +687,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } } else if(init_method == "ni") { @@ -681,7 +698,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } } else if(init_method == "uf" || init_method == "1") { @@ -690,7 +709,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } } else if(init_method == "nf") { @@ -699,7 +720,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } } else if(init_method == "tf" || init_method 
== "2") { @@ -708,7 +731,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillTrigValue{}(knew_host); ck_tile::FillTrigValue{}(v_host); ck_tile::FillTrigValue{}(vnew_host); - ck_tile::FillTrigValue{}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillTrigValue{}(bias_host); + } } else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") // suitable for fp8 quantization @@ -722,7 +747,9 @@ bool run(const ck_tile::ArgParser& arg_parser) // bias_fp8 = qscale_bias * bias_fp32 float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); // Assume bias is in [-1.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + if(bias.type!=bias_enum::elementwise_bias){ + ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + } } if(bias.type == bias_enum::alibi) { @@ -864,9 +891,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { assert(nhead % nhead_k == 0); - /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & - /// 'nhead_stride_bias' are 0. + /// NOTE: we broadcast bias from [batch, 1, 1, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'stride_bias' & + /// 'nhead_stride_bias' are 0. // setup stride_* arguments const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); @@ -884,7 +911,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_bias = 0; const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? 
hdim_v : nhead * hdim_v); @@ -908,8 +935,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); - const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck_tile::index_t nhead_stride_bias = 0; const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); @@ -925,7 +951,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = shape_seqlen_k; const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); @@ -1378,12 +1404,9 @@ bool run(const ck_tile::ArgParser& arg_parser) if(bias.type == bias_enum::elementwise_bias) { // elementwise bias - ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor bias_host_ref({1, 1, real_seqlen_k}); // clang-format off - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(b_idx, 0, 0, i[2] + key_offset); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,