Merge quad sampling into softmax

Combine temperature and smoothing factor (still separate args to sample_basic)
Allow arbitrary exponent
This commit is contained in:
turboderp
2024-02-02 15:07:43 +01:00
parent 9f8951e63b
commit b60c34770e
3 changed files with 36 additions and 77 deletions

View File

@@ -95,42 +95,13 @@ void apply_rep_penalty_cpu
}
}
void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float smoothing_factor,
float* output
)
{
// Calculate maxl as the maximum logit value
float maxl = -1e38;
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float logit_shifted = logits[i] - maxl;
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
// Limit the range of logits to prevent extreme values
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
}
softmax_cpu(vocab_size, temperature, logits, logits_filter, output);
}
void softmax_cpu
(
const int vocab_size,
const float temperature,
const float* logits,
const bool* logits_filter,
const float exponent,
float* output
)
{
@@ -138,31 +109,31 @@ void softmax_cpu
float itemp = 1.0f / temperature;
float maxl = -1e38;
#pragma unroll(32)
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}
maxl *= itemp;
#pragma unroll(32)
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float e = expf(logits[i] * itemp - maxl);
float l = logits[i] - maxl;
if (exponent == 2.0f)
l *= -l;
else if (exponent != 1.0f)
l = -powf(fabs(l), exponent);
float e = expf(l * itemp);
output[i] = e;
esum += e;
}
float isum = 1.0f / esum;
#pragma unroll(32)
for (int i = 0; i < vocab_size; i++)
{
if (logits_filter[i])
output[i] *= isum;
else
output[i] = 0.0f;
if (logits_filter[i]) output[i] *= isum;
else output[i] = 0.0f;
}
// printf("Softmax:");
@@ -171,11 +142,12 @@ void softmax_cpu
// {
// if (logits_filter[i])
// {
// printf("%d, %f\n", i, output[i]);
// summ += output[i];
// if (output[i] < 1e-5) continue;
// printf("%d, %f\n", i, output[i]);
// }
// }
// printf("sum: %f\n", summ);
// printf("sum: %f\n\n", summ);
}
int post_softmax_temperature
@@ -769,6 +741,16 @@ int multinomial_cpu
float random
)
{
// printf("\n-----------------\n");
// int j = 0;
// for (int i = 0; i < num_candidates && j < 10; ++i)
// {
// if (temp_probs[i] < 1e-6) continue;
// DBGIF(i, temp_probs[i]);
// j++;
// }
// printf("-----------------\n");
int idx = 0;
float accum = temp_probs[idx];

View File

@@ -25,16 +25,7 @@ void softmax_cpu
const float temperature,
const float* logits,
const bool* logits_filter,
float* output
);
void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float smoothing_factor,
const float exponent,
float* output
);

View File

@@ -1020,30 +1020,21 @@ std::vector<float> sample_basic
for (int i = 0; i < bsz; i++)
{
float exponent = 1.0f;
if (smoothing_factor > 0)
{
// Apply quadratic_sampling to the logits
quadratic_sampling
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
smoothing_factor,
temp_probs
);
}
else
{
softmax_cpu
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
temp_probs
);
exponent = 2.0f;
temperature /= smoothing_factor;
}
softmax_cpu
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
exponent,
temp_probs
);
if (top_k == 1)
{
@@ -1080,11 +1071,6 @@ std::vector<float> sample_basic
normalize_cpu(num_candidates, temp_probs);
}
if (smoothing_factor > 0)
{
}
if (tfs > 0.0f && tfs < 1.0f)
{
num_candidates = tfs_cpu(num_candidates, temp_probs, temp_indices, tfs);