Add Adaptive-P sampler

turboderp
2026-01-14 21:42:40 +01:00
parent 0d09af403a
commit f21b92e978
7 changed files with 222 additions and 1 deletion

View File

@@ -29,6 +29,7 @@
#include "generator/strings.h"
#include "generator/sampling_basic.cuh"
#include "generator/sampling_extra.cuh"
#include "generator/gumbel.cuh"
#include "generator/rep_pen.cuh"
#include "generator/cache.cuh"
@@ -115,6 +116,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("gumbel_noise_log", &gumbel_noise_log, "gumbel_noise_log");
m.def("apply_rep_pens", &apply_rep_pens, "apply_rep_pens");
m.def("apply_pres_freq_pens", &apply_pres_freq_pens, "apply_pres_freq_pens");
m.def("adaptivep_gumbel_noise_f32", &adaptivep_gumbel_noise_f32, "adaptivep_gumbel_noise_f32");
m.def("cache_rotate", &cache_rotate, "cache_rotate");

View File

@@ -0,0 +1,87 @@
#include <cuda_fp16.h>
#include "sampling_basic.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include "../util.h"
#include "../util.cuh"
#include <limits>
#include <curand_kernel.h>
#define NUM_THREADS 1024

// Gumbel-max trick: perturbing logits with -log(-log(u)), u ~ U(0, 1), makes the argmax a sample from
// the softmax distribution. Clamping avoids log(0)
inline __device__ float gumbel(float x)
{
    return -__logf(fmaxf(-__logf(fmaxf(x, 1e-20f)), 1e-20f));
}

constexpr float NEG_INF_F32 = -std::numeric_limits<float>::infinity();
__global__ __launch_bounds__(NUM_THREADS)
void adaptivep_gumbel_noise_kernel_f32
(
    const float* __restrict__ probs_in,
    float* __restrict__ logits,
    const int size,
    const uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
)
{
    int idx = threadIdx.x + NUM_THREADS * blockIdx.x;
    if (idx >= size) return;

    float x = probs_in[idx];
    if (x < 1e-8f)
    {
        // Token was truncated away upstream; exclude it from the argmax
        x = NEG_INF_F32;
    }
    else
    {
        curandStatePhilox4_32_10_t state;
        curand_init(random, idx, 0, &state);

        // Map distance from the adapted target probability to a faux-logit that peaks at the target
        float adapted_prob = fabsf(x - adapted_target) * inv_width;
        x = peak_logit_value - sharpness * adapted_prob * adapted_prob / (adapted_prob + 1.0f);

        // Add Gumbel noise so the subsequent argmax samples from the implied distribution
        float rf = curand_uniform(&state);
        x += gumbel(rf);
    }
    logits[idx] = x;
}
// Produces Adaptive-P faux-logits from truncated probabilities, then adds Gumbel noise
void adaptivep_gumbel_noise_f32
(
    const at::Tensor& probs_in,
    at::Tensor& logits,
    uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
)
{
    const at::cuda::OptionalCUDAGuard device_guard(logits.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(probs_in, kFloat);
    TORCH_CHECK_DTYPE(logits, kFloat);

    int size = logits.numel();
    int blocks = CEIL_DIVIDE(size, NUM_THREADS);
    adaptivep_gumbel_noise_kernel_f32<<<blocks, NUM_THREADS, 0, stream>>>
    (
        (const float*) probs_in.data_ptr(),
        (float*) logits.data_ptr(),
        size,
        random,
        adapted_target,
        inv_width,
        peak_logit_value,
        sharpness
    );
}
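
As a cross-check against the Python side, the shaping the kernel applies can be sketched in eager PyTorch. The helper name and defaults below are illustrative only (the defaults mirror SS_AdaptiveP's constants); the real path goes through the compiled extension and a Philox per-index RNG:

import torch

# Reference-only sketch of what the kernel computes (hypothetical helper, not part of the commit)
def adaptivep_gumbel_noise_ref(probs, adapted_target, inv_width = 1.0 / 0.3,
                               peak_logit_value = 5.0, sharpness = 10.0):
    masked = probs < 1e-8                              # tokens truncated away upstream
    d = (probs - adapted_target).abs() * inv_width     # scaled distance from the target probability
    logits = peak_logit_value - sharpness * d * d / (d + 1.0)
    u = torch.rand_like(probs).clamp_min(1e-20)        # U(0, 1) noise
    g = -torch.log((-torch.log(u)).clamp_min(1e-20))   # Gumbel(0, 1) noise
    return (logits + g).masked_fill(masked, float("-inf"))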

View File

@@ -0,0 +1,14 @@
#pragma once

#include <ATen/Tensor.h>

void adaptivep_gumbel_noise_f32
(
    const at::Tensor& probs_in,
    at::Tensor& logits,
    uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
);

View File

@@ -15,6 +15,7 @@ from .custom import (
    SS_NoOp,
    SS_RepP,
    SS_PresFreqP,
    SS_AdaptiveP,
)
from .presets import (
    DefaultSampler,
@@ -25,4 +26,5 @@ from .presets import (
    TopKSampler,
    TopPSampler,
    ComboSampler,
    AdaptivePSampler,
)

View File

@@ -21,6 +21,12 @@ class SS(Enum):
    PROBS_N = 6    # state.probs is valid and normalized
    PROBS_N_S = 7  # state.probs is valid and normalized, indices are valid

def clamp(n, smallest, largest):
    return max(smallest, min(n, largest))

def conditional(condition, a, b):
    # Note: unlike an inline conditional expression, both branches are evaluated before the call
    return a if condition else b

@dataclass
class SamplingState:
    rand_u32: int

@@ -448,6 +454,84 @@ class SS_PresFreqP(SS_Base):
        return True

class SS_AdaptiveP(SS_Base):
    """
    Implements the Adaptive-P sampler. Maintains a running state across sampled tokens, but does not
    remember past states; after a rollback, the state still reflects the rolled-back (future) tokens.
    """

    def __init__(
        self,
        target: float = 1.0,
        decay: float = 0.0
    ):
        self.target = target
        self.decay = decay

        # Seed the decayed running average as if past samples had hit the target exactly
        clamped_decay = max(min(decay, 0.99), 0.0)
        self.weighted_sum = target / (1.0 - clamped_decay)
        self.total_weight = 1.0 / (1.0 - clamped_decay)

        self.DISTRIBUTION_WIDTH = 0.3
        self.PEAK_LOGIT_VALUE = 5.0
        self.SHARPNESS = 10.0
        self.INV_WIDTH = 1.0 / self.DISTRIBUTION_WIDTH

        # self.log = []
    def run(self, state: SamplingState):
        match state.state:
            case SS.PROBS_N_S:
                target = clamp(self.target, 0.0, 1.0)
                # Steer the target to compensate for drift in the running average of sampled
                # probabilities. The inline conditional (rather than the conditional() helper)
                # avoids evaluating the division when total_weight == 0.0
                adapted_target = (
                    target if self.total_weight == 0.0 else
                    2.0 * target - (self.weighted_sum / self.total_weight)
                )
                adapted_target = clamp(adapted_target, 0.0, 1.0)

                state.logits = torch.empty_like(state.in_logits, dtype = torch.float)
                ext.adaptivep_gumbel_noise_f32(
                    state.probs,
                    state.logits,
                    state.rand_u32,
                    adapted_target,
                    self.INV_WIDTH,
                    self.PEAK_LOGIT_VALUE,
                    self.SHARPNESS
                )

                # Gumbel-max: argmax over the noised faux-logits is the sampled token
                temp = torch.argmax(state.logits, dim = -1)
                state.sample = state.indices[buffered_arange(state.bsz, state.in_logits.device), temp]
                sampled_prob = state.probs[0, temp].item()

                # self.log.append((adapted_target, sampled_prob))
                # if len(self.log) == 300:
                #     print("\n\n\n")
                #     s = 0
                #     for i, (a, b) in enumerate(self.log):
                #         s += b
                #         m = s / (i + 1)
                #         print(f"{i};{a};{b};{m}")
                #     print("\n\n\n")

                # Fold the sampled token's probability into the decayed running average
                self.weighted_sum = sampled_prob + self.decay * self.weighted_sum
                self.total_weight = 1.0 + self.decay * self.total_weight

            case _:
                raise ValueError("Sampling logic error")

        state.state = SS.DONE
    def prep(self, in_state: SS):
        match in_state:
            case SS.INIT | SS.LOGITS | SS.PROBS | SS.LOGITS_S | SS.PROBS_S:
                return [SS_Normalize, SS_Sort]
            case _:
                return None

    def alt(self):
        # A target of 1.0 disables Adaptive-P
        if self.target == 1.0:
            return SS_NoOp()
        return None
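
The feedback rule amounts to an exponentially decayed running mean of sampled-token probabilities, with each step's target reflected around the configured target so the mean is pulled back toward it. A standalone sketch of a few updates (hypothetical illustration, not part of the module; constants chosen for clarity):

target, decay = 0.5, 0.9
weighted_sum = target / (1.0 - decay)   # 5.0: seeded as if past samples hit the target
total_weight = 1.0 / (1.0 - decay)      # 10.0

for sampled_prob in (0.9, 0.8, 0.7):    # probabilities of the tokens actually sampled
    running_mean = weighted_sum / total_weight
    # Sampling above-target tokens drags the adapted target below 0.5, and vice versa
    adapted_target = min(max(2.0 * target - running_mean, 0.0), 1.0)
    print(f"mean = {running_mean:.3f} -> adapted_target = {adapted_target:.3f}")
    weighted_sum = sampled_prob + decay * weighted_sum
    total_weight = 1.0 + decay * total_weight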
class CustomSampler(Sampler):
    def __init__(
        self,

View File

@@ -95,6 +95,8 @@ class ComboSampler(CustomSampler):
        top_k: int = 0,
        top_p: float = 1.0,
        temp_last: bool = False,
        adaptive_target: float = 1.0,
        adaptive_decay: float = 0.9,
    ):
        # Steps with default parameters become no-ops
        stack = [
@@ -113,7 +115,31 @@ class ComboSampler(CustomSampler):
            SS_TopK(top_k),
            SS_TopP(top_p),
            SS_Temperature(temperature if temp_last else 1.0),
        ]
        # Adaptive-P samples directly from its faux-logits, so it replaces the final sampling step
        if adaptive_target != 1.0:
            stack += [
                SS_AdaptiveP(adaptive_target, adaptive_decay)
            ]
        else:
            stack += [
                SS_Sample()
            ]
        super().__init__(stack)
class AdaptivePSampler(CustomSampler):
    """
    Min-P followed by Adaptive-P
    """
    def __init__(
        self,
        min_p: float = 0.08,
        target: float = 0.5,
        decay: float = 0.9,
    ):
        stack = [
            SS_MinP(min_p),
            SS_AdaptiveP(target, decay)
        ]
        super().__init__(stack)
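
A usage sketch, assuming the presets are imported as re-exported in the __init__ diff above (the package path is an assumption, not confirmed by the diff):

# Package path is an assumption; the __init__ diff re-exports both presets
from exllamav3.generator.sampler import AdaptivePSampler, ComboSampler

# Min-P truncation, then Adaptive-P targeting a mean sampled probability of 0.5
sampler = AdaptivePSampler(min_p = 0.08, target = 0.5, decay = 0.9)

# Equivalent through ComboSampler: any adaptive_target != 1.0 swaps SS_Sample for SS_AdaptiveP
sampler = ComboSampler(min_p = 0.08, adaptive_target = 0.5, adaptive_decay = 0.9)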

View File

@@ -58,6 +58,8 @@ def add_args(
    d.min_p = defs.get("min_p", 0.08)
    d.top_k = defs.get("top_k", 0)
    d.top_p = defs.get("top_p", 1.0)
    d.adaptive_target = defs.get("adaptive_target", 1.0)
    d.adaptive_decay = defs.get("adaptive_decay", 0.9)
    parser.add_argument("-temp", "--temperature", type = float, help = f"Sampling temperature (default: {d.temperature:.1f})", default = d.temperature)
    parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
    parser.add_argument("-repp", "--repetition_penalty", type = float, help = f"Repetition penalty, HF style, 1 to disable (default: {d.repetition_penalty:.1f})", default = d.repetition_penalty)
@@ -67,6 +69,8 @@ def add_args(
    parser.add_argument("-minp", "--min_p", type = float, help = f"Min-P truncation, 0 to disable (default: {d.min_p:.2f})", default = d.min_p)
    parser.add_argument("-topk", "--top_k", type = int, help = f"Top-K truncation, 0 to disable (default: {d.top_k})", default = d.top_k)
    parser.add_argument("-topp", "--top_p", type = float, help = f"Top-P truncation, 1 to disable (default: {d.top_p:.2f})", default = d.top_p)
    parser.add_argument("-adaptive_target", "--adaptive_target", type = float, help = f"Adaptive-P target, 1 to disable (default: {d.adaptive_target:.2f})", default = d.adaptive_target)
    parser.add_argument("-adaptive_decay", "--adaptive_decay", type = float, help = f"Adaptive-P decay, if Adaptive-P is enabled (default: {d.adaptive_decay:.2f})", default = d.adaptive_decay)
    if cache:
        parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
@@ -94,6 +98,8 @@ def get_arg_sampler(args):
        top_k = args.top_k,
        top_p = args.top_p,
        temp_last = not args.temperature_first,
        adaptive_target = args.adaptive_target,
        adaptive_decay = args.adaptive_decay,
    )
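
With these flags wired through add_args and get_arg_sampler, any example script that uses the helpers can enable the sampler from the command line, e.g. python examples/chat.py -adaptive_target 0.5 -adaptive_decay 0.9 (the script name is illustrative; the flags are the ones registered above).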