Quadratic Sampling optimizations

This commit is contained in:
Alexander Abushady
2024-02-01 15:01:12 -05:00
parent 8461e6fa76
commit 3ea67828ea
3 changed files with 58 additions and 28 deletions

View File

@@ -95,41 +95,49 @@ void apply_rep_penalty_cpu
}
}
void softmax_cpu
void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float* output,
float smoothing_factor
float smoothing_factor,
float* output
)
{
// Calculate maxl as the maximum logit value
float maxl = -1e38;
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float logit_shifted = logits[i] - maxl;
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
// Limit the range of logits to prevent extreme values
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
}
softmax_cpu(vocab_size, temperature, logits, logits_filter, output);
}
void softmax_cpu
(
const int vocab_size,
const float temperature,
const float* logits,
const bool* logits_filter,
float* output
)
{
float esum = 0.0f;
float itemp = 1.0f / temperature;
float maxl = -1e38;
// Apply the quadratic transformation to the logits
if (smoothing_factor > 0.0f)
{
// Calculate maxl as the maximum logit value
float maxl = -1e38;
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float logit_shifted = logits[i] - maxl;
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
// Limit the range of logits to prevent extreme values
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
}
}
#pragma unroll(32)
for (int i = 0; i < vocab_size; i++)
{

View File

@@ -20,13 +20,22 @@ void apply_rep_penalty_cpu
);
void softmax_cpu
(
const int vocab_size,
const float temperature,
const float* logits,
const bool* logits_filter,
float* output
);
void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float* output,
float smoothing_factor
float smoothing_factor,
float* output
);
void normalize_cpu

View File

@@ -1026,8 +1026,7 @@ std::vector<float> sample_basic
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
temp_probs,
smoothing_factor
temp_probs
);
if (top_k == 1)
@@ -1065,6 +1064,20 @@ std::vector<float> sample_basic
normalize_cpu(num_candidates, temp_probs);
}
if (smoothing_factor > 0)
{
// Apply quadratic_sampling to the logits
quadratic_sampling
(
num_candidates,
temperature,
temp_probs,
logits_filter_ptr + i * vocab_size,
smoothing_factor,
temp_probs
);
}
if (tfs > 0.0f && tfs < 1.0f)
{
num_candidates = tfs_cpu(num_candidates, temp_probs, temp_indices, tfs);