mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] support group from cmdline (#1295)
* support cmdline seqlen decode * silent print * update readme * update kernel launch 3d * update tile partitioner * fix spill for bf16 * modify based on comment * modify payload_t * fix bug for alibi mode * fix alibi test err * refactor kernel launch, support select timer * add missing file * remove useless code * add some comments
This commit is contained in:
@@ -34,6 +34,7 @@ args:
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k, -1 means equal to s (default:-1)
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
|
||||
@@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[])
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
|
||||
.insert(
|
||||
"s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed. same as xformer kv_padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s",
|
||||
@@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[])
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
|
||||
batch,
|
||||
arg_parser.get_str("s"),
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"));
|
||||
|
||||
#if 0
|
||||
// clang-format off
|
||||
std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
@@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
mask_info mask = mask_info::decode(
|
||||
arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
|
||||
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
|
||||
@@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
|
||||
ck_tile::stream_config stream_config{
|
||||
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
stream_warmup,
|
||||
stream_repeat,
|
||||
arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataType>;
|
||||
|
||||
@@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
@@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
v_buf.ToDevice(v_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
@@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0]
|
||||
<< (seqlen_kpads[0] < 0 ? ""
|
||||
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
|
||||
<< std::flush;
|
||||
@@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
@@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
@@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
|
||||
|
||||
const auto v_host_ref_lengths =
|
||||
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
|
||||
@@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<SaccDataType, true>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = current_slope;
|
||||
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
|
||||
: -current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
|
||||
@@ -78,6 +78,11 @@ BOOL_MAP = {
|
||||
"f" : "false"
|
||||
}
|
||||
|
||||
TILE_PARTITIONER_MAP = {
|
||||
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
|
||||
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
|
||||
}
|
||||
|
||||
DIRECTIONS = ["fwd"]
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
|
||||
@@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_lse},
|
||||
{F_squant},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
@@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} =
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
|
||||
ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
@@ -154,7 +159,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -389,6 +394,12 @@ class FmhaFwdKernel:
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
def get_tp(self) -> str:
|
||||
if self.F_mode == 'group':
|
||||
return 'hbs'
|
||||
else:
|
||||
return 'shb'
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
@@ -413,7 +424,7 @@ class FmhaFwdKernel:
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
@@ -421,12 +432,13 @@ class FmhaFwdKernel:
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\
|
||||
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
|
||||
@@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlens_sum,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_max = -1, // if not negative, clamp max
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
std::vector<int32_t> seqlens(count, seqlens_sum);
|
||||
std::vector<int32_t> seqlens(
|
||||
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat)
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is always greater than 0
|
||||
@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
|
||||
std::vector<int32_t> generate_seqstarts(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlens_sum,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_max = -1,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed));
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
|
||||
}
|
||||
|
||||
/*
|
||||
* decode the seqlen string from cmdline
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
|
||||
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
|
||||
* q_val=1,2,3,4 -> OK, but ignore exceed one
|
||||
*
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
decode_seqlen(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
std::string q_val,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = _S2I_(q_val);
|
||||
ck_tile::index_t k = _S2I_(k_val);
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t idx = 0;
|
||||
std::string::size_type pos_q = 0;
|
||||
std::string::size_type pos_k = 0;
|
||||
std::string::size_type pos_kp = 0;
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
while(true)
|
||||
{
|
||||
auto found_q = q_val.find(',', pos_q);
|
||||
auto found_k = k_val.find(',', pos_k);
|
||||
auto found_kp = k_pad_val.find(',', pos_kp);
|
||||
|
||||
ck_tile::index_t q = _S2I_(
|
||||
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
|
||||
ck_tile::index_t k = _S2I_(
|
||||
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
|
||||
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
|
||||
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
idx++;
|
||||
if(found_q == std::string::npos || idx >= batch)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos_q = found_q + 1;
|
||||
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
|
||||
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
|
||||
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
int env_get_int(const char* var_name, int default_int)
|
||||
@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int)
|
||||
char* v = getenv(var_name);
|
||||
int r = default_int;
|
||||
if(v)
|
||||
r = atoi(v);
|
||||
r = std::atoi(v);
|
||||
return r;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user