mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 09:09:50 +00:00
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
sparams.top_n_sigma = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (arg == "--dry-multiplier") {
|
||||
CHECK_ARG
|
||||
sparams.dry_multiplier = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-base") {
|
||||
CHECK_ARG
|
||||
sparams.dry_base = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-allowed-length") {
|
||||
CHECK_ARG
|
||||
sparams.dry_allowed_length = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-penalty-last-n") {
|
||||
CHECK_ARG
|
||||
sparams.dry_penalty_last_n = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-sequence-breaker") {
|
||||
CHECK_ARG
|
||||
static bool defaults_cleared = false;
|
||||
|
||||
if (!defaults_cleared) {
|
||||
params.sparams.dry_sequence_breakers.clear();
|
||||
defaults_cleared = true;
|
||||
}
|
||||
std::string value= std::string(argv[i]);
|
||||
if (value == "none") {
|
||||
params.sparams.dry_sequence_breakers.clear();
|
||||
}
|
||||
else {
|
||||
for (size_t i; i < value.size(); i++)
|
||||
{
|
||||
params.sparams.dry_sequence_breakers.emplace_back(""+value[i]);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--cfg-negative-prompt") {
|
||||
CHECK_ARG
|
||||
sparams.cfg_negative_prompt = argv[i];
|
||||
@@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
||||
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
}
|
||||
|
||||
if (params.sparams.dry_penalty_last_n == -1) {
|
||||
LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sparams.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG("warming up the model with an empty run\n");
|
||||
|
||||
@@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
|
||||
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
|
||||
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
||||
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
|
||||
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
|
||||
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
|
||||
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
|
||||
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
||||
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "sampling.h"
|
||||
#include "llama-vocab.h"
|
||||
#include <random>
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
|
||||
result->params = params;
|
||||
@@ -36,13 +37,32 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||
}
|
||||
result->grammar = grammar;
|
||||
}
|
||||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
result->n_valid = 0;
|
||||
|
||||
// init DRY
|
||||
for (const auto& cnstr : params.samplers_sequence)
|
||||
{
|
||||
switch (cnstr)
|
||||
{
|
||||
case llama_sampler_type::DRY:
|
||||
{
|
||||
std::vector<const char*> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto& str : params.dry_sequence_breakers)
|
||||
{
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
result->smpl=llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size());
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
llama_sampling_set_rng_seed(result, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -50,7 +70,8 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
|
||||
if (ctx->grammar != NULL) {
|
||||
llama_grammar_free(ctx->grammar);
|
||||
}
|
||||
|
||||
if (ctx->smpl !=NULL)
|
||||
llama_sampler_dry_free(ctx->smpl);
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
@@ -75,6 +96,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
ctx->n_valid = 0;
|
||||
llama_sampler_dry_reset(ctx->smpl);
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
|
||||
@@ -95,6 +117,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
|
||||
}
|
||||
|
||||
dst->prev = src->prev;
|
||||
dst->smpl = llama_sampler_dry_clone(src->smpl);
|
||||
}
|
||||
|
||||
llama_token llama_sampling_last(llama_sampling_context * ctx) {
|
||||
@@ -149,6 +172,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||
|
||||
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
|
||||
switch (sampler_type) {
|
||||
case llama_sampler_type::DRY: return "dry";
|
||||
case llama_sampler_type::TOP_K: return "top_k";
|
||||
case llama_sampler_type::TFS_Z: return "tfs_z";
|
||||
case llama_sampler_type::TYPICAL_P: return "typical_p";
|
||||
@@ -163,6 +187,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
|
||||
|
||||
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
||||
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
|
||||
{"dry", llama_sampler_type::DRY},
|
||||
{"top_k", llama_sampler_type::TOP_K},
|
||||
{"top_p", llama_sampler_type::TOP_P},
|
||||
{"typical_p", llama_sampler_type::TYPICAL_P},
|
||||
@@ -176,6 +201,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
|
||||
// since samplers names are written multiple ways
|
||||
// make it ready for both system names and input names
|
||||
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
|
||||
{"dry", llama_sampler_type::DRY},
|
||||
{"top-k", llama_sampler_type::TOP_K},
|
||||
{"top-p", llama_sampler_type::TOP_P},
|
||||
{"nucleus", llama_sampler_type::TOP_P},
|
||||
@@ -215,6 +241,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
|
||||
|
||||
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
|
||||
std::unordered_map<char, llama_sampler_type> sampler_name_map {
|
||||
{'d', llama_sampler_type::DRY},
|
||||
{'k', llama_sampler_type::TOP_K},
|
||||
{'p', llama_sampler_type::TOP_P},
|
||||
{'y', llama_sampler_type::TYPICAL_P},
|
||||
@@ -238,25 +265,28 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
|
||||
|
||||
// no reasons to expose this function in header
|
||||
static void sampler_queue(
|
||||
struct llama_context * ctx_main,
|
||||
const llama_sampling_params & params,
|
||||
llama_token_data_array & cur_p,
|
||||
size_t min_keep) {
|
||||
const float temp = params.temp;
|
||||
const float dynatemp_range = params.dynatemp_range;
|
||||
struct llama_context* ctx_main,
|
||||
const llama_sampling_params& params,
|
||||
llama_sampling_context * ctx_sampling,
|
||||
llama_token_data_array& cur_p,
|
||||
size_t min_keep) {
|
||||
const float temp = params.temp;
|
||||
const float dynatemp_range = params.dynatemp_range;
|
||||
const float dynatemp_exponent = params.dynatemp_exponent;
|
||||
const int32_t top_k = params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float min_p = params.min_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const float xtc_probability = params.xtc_probability;
|
||||
const float xtc_threshold = params.xtc_threshold;
|
||||
const float top_n_sigma = params.top_n_sigma;
|
||||
const int32_t top_k = params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float min_p = params.min_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const float xtc_probability = params.xtc_probability;
|
||||
const float xtc_threshold = params.xtc_threshold;
|
||||
const float top_n_sigma = params.top_n_sigma;
|
||||
|
||||
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
|
||||
|
||||
for (auto sampler_type : samplers_sequence) {
|
||||
switch (sampler_type) {
|
||||
case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break;
|
||||
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
|
||||
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
|
||||
case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
|
||||
@@ -317,7 +347,7 @@ static llama_token llama_sampling_sample_impl(
|
||||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.min_keep);
|
||||
|
||||
sampler_queue(ctx_main, params, cur_p, min_keep);
|
||||
sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep);
|
||||
|
||||
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
|
||||
|
||||
@@ -472,4 +502,5 @@ void llama_sampling_accept(
|
||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||
}
|
||||
llama_sampler_dry_accept(ctx_sampling->smpl, id);
|
||||
}
|
||||
|
||||
@@ -35,11 +35,16 @@ typedef struct llama_sampling_params {
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t total_context_size = 16840;
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
float xtc_probability = 0.0f; // xtc probability
|
||||
@@ -48,12 +53,16 @@ typedef struct llama_sampling_params {
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
llama_sampler_type::DRY,
|
||||
llama_sampler_type::TOP_K,
|
||||
llama_sampler_type::TFS_Z,
|
||||
llama_sampler_type::TYPICAL_P,
|
||||
llama_sampler_type::TOP_P,
|
||||
llama_sampler_type::MIN_P,
|
||||
llama_sampler_type::XTC,
|
||||
llama_sampler_type::TOP_N_SIGMA,
|
||||
llama_sampler_type::TEMPERATURE
|
||||
};
|
||||
@@ -88,6 +97,8 @@ struct llama_sampling_context {
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
llama_sampler_dry* smpl;
|
||||
|
||||
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
||||
|
||||
std::mt19937 rng;
|
||||
@@ -96,7 +107,7 @@ struct llama_sampling_context {
|
||||
#include "common.h"
|
||||
|
||||
// Create a new sampling context instance.
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);
|
||||
|
||||
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);
|
||||
|
||||
while (n_remain != 0 || params.interactive) {
|
||||
// predict
|
||||
|
||||
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams);
|
||||
if (!ctx_sampling) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
|
||||
@@ -218,7 +218,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams);
|
||||
return ctx_sampling;
|
||||
}
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
|
||||
|
||||
// verification n-grams
|
||||
std::vector<ngram_data> ngrams_cur(G);
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
bool has_eos = false;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
|
||||
@@ -531,7 +531,7 @@ int main(int argc, char ** argv) {
|
||||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||
}
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);
|
||||
if (!ctx_sampling) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
|
||||
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
||||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
client.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
set(TARGET rpc-server)
|
||||
add_executable(${TARGET} rpc-server.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
if (MSVC)
|
||||
target_link_options(${TARGET} PRIVATE
|
||||
$<$<CONFIG:DEBUG>:/STACK:20971520,1048576 >
|
||||
$<$<CONFIG:RELEASE>:/STACK:20971520,1048576>
|
||||
)
|
||||
endif()
|
||||
@@ -37,7 +37,13 @@ install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
|
||||
)
|
||||
|
||||
if (MSVC)
|
||||
target_link_options(${TARGET} PRIVATE
|
||||
$<$<CONFIG:DEBUG>:/STACK:20971520,1048576 >
|
||||
$<$<CONFIG:RELEASE>:/STACK:20971520,1048576>
|
||||
)
|
||||
endif()
|
||||
# target_link_libraries(${TARGET} PRIVATE "/STACK:104857600")
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
|
||||
@@ -977,6 +977,10 @@ struct server_context {
|
||||
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
||||
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
||||
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
||||
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
|
||||
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
|
||||
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
|
||||
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
|
||||
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
@@ -987,6 +991,42 @@ struct server_context {
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
if (slot.sparams.penalty_last_n < -1) {
|
||||
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
||||
}
|
||||
|
||||
if (slot.sparams.dry_penalty_last_n < -1) {
|
||||
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
||||
}
|
||||
|
||||
if (slot.sparams.penalty_last_n == -1) {
|
||||
// note: should be the slot's context and not the full context, but it's ok
|
||||
slot.sparams.penalty_last_n = llama_n_ctx(ctx);
|
||||
}
|
||||
|
||||
if (slot.sparams.dry_penalty_last_n == -1) {
|
||||
slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
|
||||
|
||||
}
|
||||
if (slot.sparams.dry_base < 1.0f)
|
||||
{
|
||||
slot.sparams.dry_base = default_sparams.dry_base;
|
||||
}
|
||||
|
||||
// sequence breakers for DRY
|
||||
{
|
||||
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
||||
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
||||
|
||||
if (data.contains("dry_sequence_breakers")) {
|
||||
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
||||
if (slot.sparams.dry_sequence_breakers.empty()) {
|
||||
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
|
||||
@@ -1156,7 +1196,7 @@ struct server_context {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
}
|
||||
slot.ctx_sampling = llama_sampling_init(slot.sparams);
|
||||
slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams);
|
||||
if (slot.ctx_sampling == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
@@ -1405,6 +1445,11 @@ struct server_context {
|
||||
{"frequency_penalty", slot.sparams.penalty_freq},
|
||||
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
||||
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
||||
{"dry_multiplier", slot.sparams.dry_multiplier},
|
||||
{"dry_base", slot.sparams.dry_base},
|
||||
{"dry_allowed_length", slot.sparams.dry_allowed_length},
|
||||
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
|
||||
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
|
||||
{"mirostat", slot.sparams.mirostat},
|
||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
@@ -2337,6 +2382,13 @@ struct server_context {
|
||||
slot.command = SLOT_COMMAND_NONE;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
||||
llama_token id = slot.prompt_tokens[i];
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
|
||||
}
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
@@ -179,7 +179,7 @@ int main(int argc, char ** argv) {
|
||||
bool has_eos = false;
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_tgt), params.sparams);
|
||||
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
@@ -190,7 +190,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||
drafts[s].ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_dft), params.sparams);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||
|
||||
@@ -40,6 +40,8 @@
|
||||
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
|
||||
#define LLAMA_TOKEN_NULL -1
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||
@@ -556,6 +558,7 @@ extern "C" {
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||
@@ -1222,6 +1225,30 @@ extern "C" {
|
||||
llama_token_data_array * candidates_p,
|
||||
float top_n_sigma);
|
||||
|
||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||
LLAMA_API struct llama_sampler_dry * llama_sampler_init_dry(
|
||||
const struct llama_vocab* model,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const char** seq_breakers,
|
||||
size_t num_breakers);
|
||||
|
||||
//LLAMA_API void llama_sample_dry(struct llama_context* ctx, llama_token_data_array* candidates_p, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers);
|
||||
|
||||
void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p);
|
||||
|
||||
void llama_sampler_dry_reset(struct llama_sampler_dry* smpl);
|
||||
|
||||
void llama_sampler_dry_free(struct llama_sampler_dry* smpl);
|
||||
|
||||
struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl);
|
||||
|
||||
void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token);
|
||||
|
||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
|
||||
114
src/llama-impl.h
114
src/llama-impl.h
@@ -9,6 +9,7 @@
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "llama.h"
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
@@ -20,6 +21,7 @@
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||
#endif
|
||||
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
@@ -52,3 +54,115 @@ static void replace_all(std::string & s, const std::string & search, const std::
|
||||
builder.append(s, last_pos, std::string::npos);
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
template<typename T>
|
||||
struct ring_buffer {
|
||||
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||
|
||||
T& front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
const T& front() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
T& back() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
const T& back() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
void push_back(const T& value) {
|
||||
if (capacity == 0) {
|
||||
throw std::runtime_error("ring buffer: capacity is zero");
|
||||
}
|
||||
|
||||
if (sz == capacity) {
|
||||
// advance the start when buffer is full
|
||||
first = (first + 1) % capacity;
|
||||
}
|
||||
else {
|
||||
sz++;
|
||||
}
|
||||
data[pos] = value;
|
||||
pos = (pos + 1) % capacity;
|
||||
}
|
||||
|
||||
T pop_front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
T value = data[first];
|
||||
first = (first + 1) % capacity;
|
||||
sz--;
|
||||
return value;
|
||||
}
|
||||
|
||||
//T & operator[](size_t i) {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
//const T & at(size_t i) const {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
const T& rat(size_t i) const {
|
||||
if (i >= sz) {
|
||||
throw std::runtime_error("ring buffer: index out of bounds");
|
||||
}
|
||||
return data[(first + sz - i - 1) % capacity];
|
||||
}
|
||||
|
||||
std::vector<T> to_vector() const {
|
||||
std::vector<T> result;
|
||||
result.reserve(sz);
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
result.push_back(data[(first + i) % capacity]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
// here only reset the status of the buffer
|
||||
sz = 0;
|
||||
first = 0;
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return sz == 0;
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return sz;
|
||||
}
|
||||
|
||||
size_t capacity = 0;
|
||||
size_t sz = 0;
|
||||
size_t first = 0;
|
||||
size_t pos = 0;
|
||||
std::vector<T> data;
|
||||
};
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "llama-sampling.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
@@ -469,7 +471,7 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
|
||||
}
|
||||
|
||||
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) {
|
||||
|
||||
|
||||
if (top_n_sigma <= 0.0f || candidates->size < 4) {
|
||||
// top_n_sigma <= 0: disabled
|
||||
// candidates->size < 4: no point in applying the transformation for fewer than 4 logits.
|
||||
@@ -725,3 +727,310 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama
|
||||
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
|
||||
}
|
||||
|
||||
|
||||
// DRY
|
||||
|
||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||
static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
||||
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) {
|
||||
std::string word = llama_detokenize(vocab, { token_id }, true);
|
||||
if (word.find(str) != std::string::npos) {
|
||||
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||
}
|
||||
else {
|
||||
size_t word_len = word.size(), str_len = str.size();
|
||||
size_t pos = -1;
|
||||
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
||||
bool match = true;
|
||||
size_t i;
|
||||
for (i = 1; i < str_len && i + pos < word_len; ++i) {
|
||||
if (word[pos + i] != str[i]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
||||
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
||||
tokenization.resize(max_tail_len);
|
||||
}
|
||||
|
||||
// Ensure we don't already have a duplicate matching tokenization
|
||||
auto its = token_sequences.equal_range(token_id);
|
||||
bool found = false;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
if (tokenization == it->second) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
token_sequences.emplace(token_id, tokenization);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const char* llama_sampler_dry_name(const struct llama_sampler* /*smpl*/) {
|
||||
return "dry";
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p) {
|
||||
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t effective_dry_penalty_last_n = (smpl->dry_penalty_last_n == -1) ? smpl->total_context_size : std::max(smpl->dry_penalty_last_n, 0);
|
||||
int last_n_repeat = std::min(std::min((int)smpl->last_tokens.size(), effective_dry_penalty_last_n), smpl->total_context_size);
|
||||
|
||||
if (last_n_repeat <= smpl->dry_allowed_length) {
|
||||
return;
|
||||
}
|
||||
|
||||
smpl->dry_repeat_count.assign(last_n_repeat, 0);
|
||||
smpl->dry_max_token_repeat.clear();
|
||||
|
||||
// Step 1: Look for restart sequences to limit the maximum repetition length.
|
||||
// Work backwards through the context looking for any token that begins a restart sequence.
|
||||
//
|
||||
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
|
||||
// sequences that together comprise a restart sequence. This allows us to quickly check
|
||||
// whether each token is the head of a complete sequence. Most restart sequences are actually
|
||||
// a single token, and for these the "tail" is an empty vector.
|
||||
//
|
||||
// If the token is a "head", test all restart sequences that begin with this token
|
||||
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
|
||||
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
|
||||
// longest matching sequence (if any) is used to limit the maximum repetition length.
|
||||
//
|
||||
// Note that in the case case of a short sequence contained in a longer one, this might fail to
|
||||
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
|
||||
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
|
||||
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
|
||||
//
|
||||
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
|
||||
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
|
||||
// With clamping, this scan is O(N) in the context length.
|
||||
|
||||
int rep_limit = last_n_repeat;
|
||||
for (int i = 0; i < last_n_repeat; ++i) {
|
||||
llama_token token = smpl->last_tokens.rat(i);
|
||||
auto its = smpl->dry_processed_breakers.equal_range(token);
|
||||
if (its.first == smpl->dry_processed_breakers.end()) {
|
||||
continue;
|
||||
}
|
||||
int longest_match = -1;
|
||||
for (auto it = its.first; it != its.second; ++it) {
|
||||
// Note that (*it) does not contain the head character, so seq_len will be
|
||||
// the restart sequence length minus 1.
|
||||
// In the common case of a single-token restart sequence, (*it) will be empty
|
||||
// and we will trivially match.
|
||||
int seq_len = (int)it->second.size();
|
||||
if (seq_len > longest_match && seq_len <= (int)i) {
|
||||
bool match = true;
|
||||
for (int offset = 0; offset < seq_len; ++offset) {
|
||||
// The -1 when indexing `last_tokens` is because we already matched the head.
|
||||
if (it->second[offset] != smpl->last_tokens.rat(i - offset - 1)) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
longest_match = seq_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (longest_match >= 0) {
|
||||
// We found a restart sequence starting `i` tokens from the end and continuing for
|
||||
// `longest_match` tokens.
|
||||
rep_limit = i - longest_match;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (rep_limit < smpl->dry_allowed_length) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
|
||||
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
|
||||
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
|
||||
//
|
||||
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
|
||||
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
|
||||
//
|
||||
// The code below is adapted from the public domain implementation by the same author here:
|
||||
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
|
||||
//
|
||||
// Example:
|
||||
// Last N tokens: a b c c b c y a b c
|
||||
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
||||
// ^
|
||||
// This `3` means that the last three tokens of the context (a b c) also appear here.
|
||||
//
|
||||
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
|
||||
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
|
||||
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
|
||||
// ensure that the inner while loops only examine each token in the context once as the outer
|
||||
// for loop iterates over the context.
|
||||
|
||||
{
|
||||
const int last = last_n_repeat - 1;
|
||||
int rt = 0, lt = 0;
|
||||
|
||||
for (int k = 1; k < last_n_repeat; ++k) {
|
||||
if (k > rt) {
|
||||
// If k is outside the current Z-box, do naive computation.
|
||||
int n = 0;
|
||||
while (n + k < last_n_repeat && smpl->last_tokens.rat(n) == smpl->last_tokens.rat(n + k)) {
|
||||
++n;
|
||||
}
|
||||
smpl->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
||||
if (n > 0) {
|
||||
lt = k;
|
||||
rt = k + n - 1;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// If k is inside the current Z-box, consider two cases.
|
||||
|
||||
int p = k - lt; // Pair index.
|
||||
int right_part_len = rt - k + 1;
|
||||
|
||||
if (smpl->dry_repeat_count[last - p] < right_part_len) {
|
||||
int n = std::min(smpl->dry_repeat_count[last - p], rep_limit);
|
||||
smpl->dry_repeat_count[last - k] = n;
|
||||
}
|
||||
else {
|
||||
int i = rt + 1;
|
||||
while (i < last_n_repeat && smpl->last_tokens.rat(i) == smpl->last_tokens.rat(i - k)) {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
int n = std::min(i - k, rep_limit);
|
||||
smpl->dry_repeat_count[last - k] = n;
|
||||
lt = k;
|
||||
rt = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
|
||||
// that would be generated by emitting each new token that would extend a sequence.
|
||||
//
|
||||
// Following the same example as above:
|
||||
// Last N tokens: a b c c b c y a b c
|
||||
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
|
||||
//
|
||||
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
|
||||
// c: 3 -> 4 (from `a b c` to `a b c c`)
|
||||
// b: 1 -> 2 (from `c` to `c b`)
|
||||
// y: 2 -> 3 (from `b c` to `b c y`)
|
||||
|
||||
for (int i = 0; i < last_n_repeat - 1; ++i) {
|
||||
int repeat_len = smpl->dry_repeat_count[i];
|
||||
if (repeat_len >= smpl->dry_allowed_length) {
|
||||
// This token ends a repeat, so the next token would continue one.
|
||||
// By convention, the value of `repeat_len` only includes the tokens currently
|
||||
// in the context, not the new token that would be added.
|
||||
llama_token token = smpl->last_tokens.rat(last_n_repeat - 2 - i);
|
||||
// Track the maximum sequence ending in this token.
|
||||
const auto& it = smpl->dry_max_token_repeat.find(token);
|
||||
if (it == smpl->dry_max_token_repeat.end() || it->second < repeat_len) {
|
||||
smpl->dry_max_token_repeat[token] = repeat_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
|
||||
|
||||
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
|
||||
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
|
||||
const float FLOAT_MAX_LOG = 88.7228391f;
|
||||
int max_exponent = 0;
|
||||
if (smpl->dry_base > 1.000001f) {
|
||||
max_exponent = FLOAT_MAX_LOG / std::log(smpl->dry_base);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
const auto& af_kvp = smpl->dry_max_token_repeat.find(cur_p->data[i].id);
|
||||
if (af_kvp != smpl->dry_max_token_repeat.end()) {
|
||||
// Check all sequence breakers starting with this token
|
||||
auto range = smpl->dry_processed_breakers.equal_range(cur_p->data[i].id);
|
||||
bool is_single_token_breaker = false;
|
||||
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
if (it->second.empty()) {
|
||||
is_single_token_breaker = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply penalty only if it's not a single-token sequence breaker
|
||||
if (!is_single_token_breaker) {
|
||||
int repeat_exp = af_kvp->second - smpl->dry_allowed_length;
|
||||
if (max_exponent > 0 && repeat_exp > max_exponent) {
|
||||
repeat_exp = max_exponent;
|
||||
}
|
||||
float penalty = smpl->dry_multiplier * std::pow(smpl->dry_base, repeat_exp);
|
||||
cur_p->data[i].logit -= penalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur_p->sorted = false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
||||
const int MAX_CHAR_LEN = 40;
|
||||
const int MAX_SEQ_LEN = 20;
|
||||
|
||||
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
||||
|
||||
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
||||
// Process sequence breakers
|
||||
for (size_t i = 0; i < num_breakers; ++i) {
|
||||
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
|
||||
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string sequence_break(seq_breakers[i]);
|
||||
if (sequence_break.empty()) {
|
||||
LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (sequence_break.size() > MAX_CHAR_LEN) {
|
||||
LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
|
||||
sequence_break.resize(MAX_CHAR_LEN);
|
||||
}
|
||||
|
||||
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
return new llama_sampler_dry {
|
||||
/* .total_context_size = */ context_size,
|
||||
/* .dry_multiplier = */ dry_multiplier,
|
||||
/* .dry_base = */ dry_base,
|
||||
/* .dry_allowed_length = */ dry_allowed_length,
|
||||
/* .dry_penalty_last_n = */ dry_penalty_last_n,
|
||||
/* .dry_processed_breakers = */ std::move(processed_breakers),
|
||||
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
||||
/* .dry_max_token_repeat = */ {},
|
||||
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include <unordered_map>
|
||||
struct llama_sampling {
|
||||
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
|
||||
|
||||
@@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_
|
||||
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
|
||||
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
|
||||
|
||||
struct llama_sampler_dry {
|
||||
int32_t total_context_size;
|
||||
|
||||
const float dry_multiplier;
|
||||
const float dry_base;
|
||||
const int32_t dry_allowed_length;
|
||||
const int32_t dry_penalty_last_n;
|
||||
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
|
||||
std::vector<int> dry_repeat_count;
|
||||
std::unordered_map<llama_token, int> dry_max_token_repeat;
|
||||
ring_buffer<llama_token> last_tokens;
|
||||
};
|
||||
|
||||
struct llama_sampler_dry * llama_sampler_init_dry_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
int32_t context_size,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const char ** seq_breakers,
|
||||
size_t num_breakers);
|
||||
|
||||
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
|
||||
|
||||
|
||||
|
||||
void llama_sample_repetition_penalties_impl(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates,
|
||||
@@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll
|
||||
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -75,6 +75,9 @@ struct naive_trie {
|
||||
llama_token value;
|
||||
};
|
||||
|
||||
uint32_t llama_vocab::n_tokens() const {
|
||||
return (uint32_t)id_to_token.size();
|
||||
}
|
||||
//
|
||||
// impl
|
||||
//
|
||||
@@ -1741,3 +1744,19 @@ int32_t llama_detokenize_impl(
|
||||
|
||||
return total <= text_len_max ? total : -total;
|
||||
}
|
||||
|
||||
std::string llama_detokenize(const struct llama_vocab& vocab, const std::vector<llama_token>& tokens, bool special) {
|
||||
std::string text;
|
||||
text.resize(std::max(text.capacity(), tokens.size()));
|
||||
int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
||||
if (n_chars < 0) {
|
||||
text.resize(-n_chars);
|
||||
n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
||||
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
|
||||
}
|
||||
|
||||
text.resize(n_chars);
|
||||
|
||||
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
|
||||
return text;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ struct llama_vocab {
|
||||
|
||||
int max_token_len = 0; // used for optimizing longest token search
|
||||
|
||||
uint32_t n_tokens() const;
|
||||
|
||||
std::unordered_map<token, id> token_to_id;
|
||||
std::vector<token_data> id_to_token;
|
||||
|
||||
@@ -130,3 +132,8 @@ int32_t llama_detokenize_impl(
|
||||
int32_t text_len_max,
|
||||
bool remove_special,
|
||||
bool unparse_special);
|
||||
|
||||
std::string llama_detokenize(
|
||||
const struct llama_vocab& vocab,
|
||||
const std::vector<llama_token>& tokens,
|
||||
bool special);
|
||||
|
||||
@@ -20849,6 +20849,10 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
|
||||
return model->vocab.type;
|
||||
}
|
||||
|
||||
const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
|
||||
return &model->vocab;
|
||||
}
|
||||
|
||||
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
// these models do not use RoPE
|
||||
@@ -23280,6 +23284,11 @@ void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array
|
||||
llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
|
||||
}
|
||||
|
||||
|
||||
void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) {
|
||||
llama_sampler_dry_apply(smpl, candidates_p);
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
@@ -23327,6 +23336,42 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
|
||||
}
|
||||
|
||||
void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
|
||||
smpl->last_tokens.clear();
|
||||
smpl->dry_repeat_count.clear();
|
||||
smpl->dry_max_token_repeat.clear();
|
||||
}
|
||||
|
||||
void llama_sampler_dry_free(struct llama_sampler_dry* smpl) {
|
||||
delete smpl;
|
||||
}
|
||||
|
||||
struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) {
|
||||
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
||||
auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0);
|
||||
// Copy the state, including the processed breakers
|
||||
{
|
||||
auto* result_ctx = smpl;
|
||||
result_ctx->dry_processed_breakers = smpl->dry_processed_breakers;
|
||||
result_ctx->dry_repeat_count = smpl->dry_repeat_count;
|
||||
result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat;
|
||||
result_ctx->last_tokens = smpl->last_tokens;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
|
||||
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
smpl->last_tokens.push_back(token);
|
||||
}
|
||||
|
||||
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
||||
std::string str_split_path(split_path);
|
||||
char postfix[32];
|
||||
|
||||
Reference in New Issue
Block a user