mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Kalomaze's Quadratic Sampling
Quadratic Sampling
This commit is contained in:
@@ -36,6 +36,7 @@ parser.add_argument("-bn", "--botname", type = str, default = "Chatbort", help =
|
||||
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
|
||||
|
||||
parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
|
||||
# Smoothing factor for quadratic sampling; copied to settings.smoothing_factor, where values <= 0 leave the logits untouched.
parser.add_argument("-smthfctr", "--smoothing_factor", type = float, default = 0.0, help = "Smoothing Factor, default = 0.0 (0 to disable)")
|
||||
parser.add_argument("-dyntemp", "--dynamic_temperature", type = str, help = "Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1")
|
||||
parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
|
||||
parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
|
||||
@@ -196,6 +197,7 @@ settings.typical = args.typical
|
||||
settings.token_repetition_penalty = args.repetition_penalty
|
||||
settings.token_frequency_penalty = args.frequency_penalty
|
||||
settings.token_presence_penalty = args.presence_penalty
|
||||
settings.smoothing_factor = args.smoothing_factor
|
||||
|
||||
if args.dynamic_temperature:
|
||||
dt_args = [float(alloc) for alloc in args.dynamic_temperature.split(",")]
|
||||
|
||||
@@ -99,15 +99,37 @@ void softmax_cpu
|
||||
(
|
||||
const int vocab_size,
|
||||
const float temperature,
|
||||
const float* logits,
|
||||
float* logits,
|
||||
const bool* logits_filter,
|
||||
float* output
|
||||
float* output,
|
||||
float smoothing_factor
|
||||
)
|
||||
{
|
||||
float esum = 0.0f;
|
||||
float itemp = 1.0f / temperature;
|
||||
float maxl = -1e38;
|
||||
|
||||
// Apply the quadratic transformation to the logits
|
||||
if (smoothing_factor > 0.0f)
|
||||
{
|
||||
// Calculate maxl as the maximum logit value
|
||||
float maxl = -1e38;
|
||||
for (int i = 0; i < vocab_size; i++)
|
||||
{
|
||||
if (!logits_filter[i]) continue;
|
||||
maxl = fmaxf(logits[i], maxl);
|
||||
}
|
||||
|
||||
for (int i = 0; i < vocab_size; i++)
|
||||
{
|
||||
if (!logits_filter[i]) continue;
|
||||
float logit_shifted = logits[i] - maxl;
|
||||
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
|
||||
// Limit the range of logits to prevent extreme values
|
||||
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll(32)
|
||||
for (int i = 0; i < vocab_size; i++)
|
||||
{
|
||||
|
||||
@@ -23,9 +23,10 @@ void softmax_cpu
|
||||
(
|
||||
const int vocab_size,
|
||||
const float temperature,
|
||||
const float* logits,
|
||||
float* logits,
|
||||
const bool* logits_filter,
|
||||
float* output
|
||||
float* output,
|
||||
float smoothing_factor
|
||||
);
|
||||
|
||||
void normalize_cpu
|
||||
|
||||
@@ -987,7 +987,8 @@ std::vector<float> sample_basic
|
||||
float post_temperature,
|
||||
float min_temp = 0,
|
||||
float max_temp = 0.0f,
|
||||
float temp_exponent = 1.0f
|
||||
float temp_exponent = 1.0f,
|
||||
float smoothing_factor = 0.0f
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(logits, kFloat);
|
||||
@@ -1025,7 +1026,8 @@ std::vector<float> sample_basic
|
||||
temperature,
|
||||
logits_ptr + i * vocab_size,
|
||||
logits_filter_ptr + i * vocab_size,
|
||||
temp_probs
|
||||
temp_probs,
|
||||
smoothing_factor
|
||||
);
|
||||
|
||||
if (top_k == 1)
|
||||
|
||||
@@ -15,6 +15,7 @@ class ExLlamaV2Sampler:
|
||||
token_presence_penalty = 0.0
|
||||
|
||||
temperature = 0.8
|
||||
smoothing_factor = 0.0
|
||||
min_temp = 0
|
||||
max_temp = 0.0
|
||||
temp_exponent = 1.0
|
||||
@@ -50,6 +51,7 @@ class ExLlamaV2Sampler:
|
||||
c.token_presence_penalty = self.token_presence_penalty
|
||||
|
||||
c.temperature = self.temperature
|
||||
c.smoothing_factor = self.smoothing_factor
|
||||
c.min_temp = self.min_temp
|
||||
c.max_temp = self.max_temp
|
||||
c.temp_exponent = self.temp_exponent
|
||||
@@ -220,7 +222,8 @@ class ExLlamaV2Sampler:
|
||||
settings.temperature if settings.temperature_last else 1.0,
|
||||
settings.min_temp,
|
||||
settings.max_temp,
|
||||
settings.temp_exponent)
|
||||
settings.temp_exponent,
|
||||
settings.smoothing_factor)
|
||||
|
||||
if settings.mirostat: settings.mirostat_mu = m
|
||||
|
||||
|
||||
Reference in New Issue
Block a user