Kalomaze's Quadratic Sampling

Quadratic Sampling
This commit is contained in:
Alexander Abushady
2024-02-01 00:11:44 -05:00
parent 9c3fd9df3a
commit 8461e6fa76
5 changed files with 37 additions and 7 deletions

View File

@@ -36,6 +36,7 @@ parser.add_argument("-bn", "--botname", type = str, default = "Chatbort", help =
# Command-line options for the system prompt and the sampler.
# Where a help string names a default, it also gives (in parentheses) the value that disables that option.
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
# Help text previously lacked its closing parenthesis, unlike the neighbouring options.
parser.add_argument("-smthfctr", "--smoothing_factor", type = float, default = 0.0, help = "Smoothing Factor, default = 0.0 (0 to disable)")
# Takes one comma-separated string "min,max,exponent"; no default is given.
parser.add_argument("-dyntemp", "--dynamic_temperature", type = str, help = "Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1")
parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
@@ -196,6 +197,7 @@ settings.typical = args.typical
settings.token_repetition_penalty = args.repetition_penalty
settings.token_frequency_penalty = args.frequency_penalty
settings.token_presence_penalty = args.presence_penalty
settings.smoothing_factor = args.smoothing_factor
if args.dynamic_temperature:
dt_args = [float(alloc) for alloc in args.dynamic_temperature.split(",")]

View File

@@ -99,15 +99,37 @@ void softmax_cpu
(
const int vocab_size,
const float temperature,
const float* logits,
float* logits,
const bool* logits_filter,
float* output
float* output,
float smoothing_factor
)
{
float esum = 0.0f;
float itemp = 1.0f / temperature;
float maxl = -1e38;
// Apply the quadratic transformation to the logits
if (smoothing_factor > 0.0f)
{
// Calculate maxl as the maximum logit value
float maxl = -1e38;
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float logit_shifted = logits[i] - maxl;
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
// Limit the range of logits to prevent extreme values
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
}
}
#pragma unroll(32)
for (int i = 0; i < vocab_size; i++)
{

View File

@@ -23,9 +23,10 @@ void softmax_cpu
(
const int vocab_size,
const float temperature,
const float* logits,
float* logits,
const bool* logits_filter,
float* output
float* output,
float smoothing_factor
);
void normalize_cpu

View File

@@ -987,7 +987,8 @@ std::vector<float> sample_basic
float post_temperature,
float min_temp = 0,
float max_temp = 0.0f,
float temp_exponent = 1.0f
float temp_exponent = 1.0f,
float smoothing_factor = 0.0f
)
{
TORCH_CHECK_DTYPE(logits, kFloat);
@@ -1025,7 +1026,8 @@ std::vector<float> sample_basic
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
temp_probs
temp_probs,
smoothing_factor
);
if (top_k == 1)

View File

@@ -15,6 +15,7 @@ class ExLlamaV2Sampler:
token_presence_penalty = 0.0
temperature = 0.8
smoothing_factor = 0.0
min_temp = 0
max_temp = 0.0
temp_exponent = 1.0
@@ -50,6 +51,7 @@ class ExLlamaV2Sampler:
c.token_presence_penalty = self.token_presence_penalty
c.temperature = self.temperature
c.smoothing_factor = self.smoothing_factor
c.min_temp = self.min_temp
c.max_temp = self.max_temp
c.temp_exponent = self.temp_exponent
@@ -220,7 +222,8 @@ class ExLlamaV2Sampler:
settings.temperature if settings.temperature_last else 1.0,
settings.min_temp,
settings.max_temp,
settings.temp_exponent)
settings.temp_exponent,
settings.smoothing_factor)
if settings.mirostat: settings.mirostat_mu = m