mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Add Adaptive-P sampler
This commit is contained in:
@@ -29,6 +29,7 @@
|
||||
|
||||
#include "generator/strings.h"
|
||||
#include "generator/sampling_basic.cuh"
|
||||
#include "generator/sampling_extra.cuh"
|
||||
#include "generator/gumbel.cuh"
|
||||
#include "generator/rep_pen.cuh"
|
||||
#include "generator/cache.cuh"
|
||||
@@ -115,6 +116,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("gumbel_noise_log", &gumbel_noise_log, "gumbel_noise_log");
|
||||
m.def("apply_rep_pens", &apply_rep_pens, "apply_rep_pens");
|
||||
m.def("apply_pres_freq_pens", &apply_pres_freq_pens, "apply_pres_freq_pens");
|
||||
m.def("adaptivep_gumbel_noise_f32", &adaptivep_gumbel_noise_f32, "adaptivep_gumbel_noise_f32");
|
||||
|
||||
m.def("cache_rotate", &cache_rotate, "cache_rotate");
|
||||
|
||||
|
||||
87
exllamav3/exllamav3_ext/generator/sampling_extra.cu
Normal file
87
exllamav3/exllamav3_ext/generator/sampling_extra.cu
Normal file
@@ -0,0 +1,87 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include "sampling_basic.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "../util.h"
|
||||
#include "../util.cuh"
|
||||
#include <limits>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#define NUM_THREADS 1024
|
||||
|
||||
// Transforms a uniform random value x in (0, 1] into standard Gumbel noise
// via -log(-log(x)). Intermediates are clamped to 1e-20f so log(0) cannot
// produce inf/NaN. Literals carry the f suffix to avoid silent promotion of
// the whole expression to double.
inline __device__ float gumbel(float x)
{
    return -__logf(fmaxf(-__logf(fmaxf(x, 1e-20f)), 1e-20f));
}
|
||||
|
||||
// Sentinel used to mask out truncated token positions so argmax never picks them
constexpr float NEG_INF_F32 = -std::numeric_limits<float>::infinity();

// Produces adaptive-P faux-logits from truncated, normalized probabilities and
// adds Gumbel noise so the caller can sample with a single argmax (Gumbel-max
// trick). One thread per element; launched as a 1D grid of NUM_THREADS-sized
// blocks with a bounds check for the ragged tail.
//
// probs_in:          input probabilities, float32, length `size`
// logits:            output faux-logits (+ Gumbel noise), float32, length `size`
// random:            RNG seed; each element gets an independent Philox sequence
// adapted_target:    probability value the faux-logit curve peaks at
// inv_width:         1 / distribution width, scales distance from the target
// peak_logit_value:  logit value assigned at the peak
// sharpness:         how steeply logits fall off away from the target
__global__ __launch_bounds__(NUM_THREADS)
void adaptivep_gumbel_noise_kernel_f32
(
    const float* __restrict__ probs_in,
    float* __restrict__ logits,
    const int size,
    const uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
)
{
    int idx = threadIdx.x + NUM_THREADS * blockIdx.x;
    if (idx >= size) return;

    float x = probs_in[idx];
    if (x < 1e-8f)
    {
        // Truncated / negligible probability: mask out entirely
        x = NEG_INF_F32;
    }
    else
    {
        curandStatePhilox4_32_10_t state;
        curand_init(random, idx, 0, &state);

        // Scaled distance from the adapted target, squashed through
        // d^2 / (d + 1) to shape a smooth peak around the target probability.
        // fabsf and the 1.0f literal keep the arithmetic in single precision.
        float adapted_prob = fabsf(x - adapted_target) * inv_width;
        x = peak_logit_value - sharpness * adapted_prob * adapted_prob / (adapted_prob + 1.0f);
        float rf = curand_uniform(&state);
        x += gumbel(rf);
    }

    logits[idx] = x;
}
|
||||
|
||||
// Produces adaptive-P faux-logits from truncated probabilities, then adds gumbel noise
//
// Host-side launcher for adaptivep_gumbel_noise_kernel_f32. Runs on the
// current CUDA stream of logits' device.
//
// probs_in:  float32 tensor of (truncated, normalized) probabilities
// logits:    float32 output tensor; must have the same element count as probs_in
// random:    seed for the per-call Philox RNG
// remaining parameters shape the faux-logit curve (see kernel comment)
void adaptivep_gumbel_noise_f32
(
    const at::Tensor& probs_in,
    at::Tensor& logits,
    uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
)
{
    const at::cuda::OptionalCUDAGuard device_guard(logits.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(probs_in, kFloat);
    TORCH_CHECK_DTYPE(logits, kFloat);
    TORCH_CHECK(probs_in.numel() == logits.numel(), "probs_in and logits must have the same number of elements");

    int size = logits.numel();

    // A grid dimension of zero is an invalid launch configuration; nothing to do anyway
    if (size == 0) return;

    int blocks = CEIL_DIVIDE(size, NUM_THREADS);

    adaptivep_gumbel_noise_kernel_f32<<<blocks, NUM_THREADS, 0, stream>>>
    (
        (const float*) probs_in.data_ptr(),
        (float*) logits.data_ptr(),
        size,
        random,
        adapted_target,
        inv_width,
        peak_logit_value,
        sharpness
    );
}
|
||||
14
exllamav3/exllamav3_ext/generator/sampling_extra.cuh
Normal file
14
exllamav3/exllamav3_ext/generator/sampling_extra.cuh
Normal file
@@ -0,0 +1,14 @@
|
||||
#pragma once

#include <ATen/Tensor.h>

// Produces adaptive-P faux-logits from the truncated probabilities in probs_in,
// adds Gumbel noise, and writes the result to logits (sampling then reduces to
// an argmax over logits). Both tensors must be float32 with equal element
// counts; `random` seeds the per-call RNG. adapted_target, inv_width,
// peak_logit_value and sharpness shape the faux-logit curve around the target
// probability — see the definition in sampling_extra.cu.
void adaptivep_gumbel_noise_f32
(
    const at::Tensor& probs_in,
    at::Tensor& logits,
    uint32_t random,
    float adapted_target,
    float inv_width,
    float peak_logit_value,
    float sharpness
);
|
||||
@@ -15,6 +15,7 @@ from .custom import (
|
||||
SS_NoOp,
|
||||
SS_RepP,
|
||||
SS_PresFreqP,
|
||||
SS_AdaptiveP,
|
||||
)
|
||||
from .presets import (
|
||||
DefaultSampler,
|
||||
@@ -25,4 +26,5 @@ from .presets import (
|
||||
TopKSampler,
|
||||
TopPSampler,
|
||||
ComboSampler,
|
||||
AdaptivePSampler,
|
||||
)
|
||||
@@ -21,6 +21,12 @@ class SS(Enum):
|
||||
PROBS_N = 6 # state.probs is valid and normalized
|
||||
PROBS_N_S = 7 # state.probs is valid and normalized, indices are valid
|
||||
|
||||
def clamp(n, smallest, largest):
    """Limit n to the range [smallest, largest] (upper bound applied first)."""
    upper_bounded = min(n, largest)
    return max(smallest, upper_bounded)
|
||||
|
||||
def conditional(condition, a, b):
    """Select a when condition is truthy, otherwise b."""
    if condition:
        return a
    return b
|
||||
|
||||
@dataclass
|
||||
class SamplingState:
|
||||
rand_u32: int
|
||||
@@ -448,6 +454,84 @@ class SS_PresFreqP(SS_Base):
|
||||
return True
|
||||
|
||||
|
||||
class SS_AdaptiveP(SS_Base):
    """
    Implements Adaptive-P sampler. Maintains state but does not remember past states (keeps future state in case
    of rollback).
    """
    def __init__(
        self,
        target: float = 1.0,
        decay: float = 0.0
    ):
        # target: probability value the sampler steers sampled tokens toward
        # decay: exponential decay of past samples' influence on the running average
        self.target = target
        self.decay = decay
        # Clamp decay below 1.0 so the geometric series 1 / (1 - decay) is finite;
        # seed the running sums as if every past sample had hit the target exactly
        clamped_decay = max(min(decay, 0.99), 0.0)
        self.weighted_sum = target / (1.0 - clamped_decay)
        self.total_weight = 1.0 / (1.0 - clamped_decay)

        # Tuning constants for the faux-logit curve built by the CUDA kernel
        self.DISTRIBUTION_WIDTH = 0.3
        self.PEAK_LOGIT_VALUE = 5.0
        self.SHARPNESS = 10.0
        self.INV_WIDTH = 1.0 / self.DISTRIBUTION_WIDTH

        # self.log = []

    def run(self, state: SamplingState):
        # Requires normalized, sorted probabilities; prep() requests the
        # normalize/sort steps for any earlier pipeline state.
        match state.state:
            case SS.PROBS_N_S:
                target = clamp(self.target, 0.0, 1.0)
                # Compensate for the running average of previously sampled
                # probabilities so the long-run average tracks the target
                adapted_target = conditional(
                    self.total_weight == 0.0,
                    target,
                    2.0 * target - (self.weighted_sum / self.total_weight)
                )
                adapted_target = clamp(adapted_target, 0.0, 1.0)

                # Build faux-logits peaked near adapted_target with Gumbel noise
                # added, then sample via argmax (Gumbel-max trick)
                state.logits = torch.empty_like(state.in_logits, dtype = torch.float)
                ext.adaptivep_gumbel_noise_f32(
                    state.probs,
                    state.logits,
                    state.rand_u32,
                    adapted_target,
                    self.INV_WIDTH,
                    self.PEAK_LOGIT_VALUE,
                    self.SHARPNESS
                )

                # Map argmax positions back to original token ids via the sort indices
                temp = torch.argmax(state.logits, dim = -1)
                state.sample = state.indices[buffered_arange(state.bsz, state.in_logits.device), temp]
                # NOTE(review): indexes batch row 0 only and calls .item() —
                # appears to assume bsz == 1; confirm behavior for batched sampling
                sampled_prob = state.probs[0, temp].item()

                # self.log.append((adapted_target, sampled_prob))
                # if len(self.log) == 300:
                #     print("\n\n\n")
                #     s = 0
                #     for i, (a, b) in enumerate(self.log):
                #         s += b
                #         m = s / (i + 1)
                #         print(f"{i};{a};{b};{m}")
                #     print("\n\n\n")

                # Fold the sampled probability into the decayed running sums
                self.weighted_sum = sampled_prob + self.decay * self.weighted_sum
                self.total_weight = 1.0 + self.decay * self.total_weight
            case _:
                raise ValueError("Sampling logic error")
        state.state = SS.DONE

    def prep(self, in_state: SS):
        # From any earlier pipeline state, request normalization + sorting first
        match in_state:
            case SS.INIT | SS.LOGITS | SS.PROBS | SS.LOGITS_S | SS.PROBS_S:
                return [SS_Normalize, SS_Sort]
            case _:
                return None

    def alt(self):
        # A target of exactly 1.0 disables the sampler; substitute a no-op
        if self.target == 1.0:
            return SS_NoOp()
        return None
|
||||
|
||||
|
||||
class CustomSampler(Sampler):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -95,6 +95,8 @@ class ComboSampler(CustomSampler):
|
||||
top_k: int = 0,
|
||||
top_p: float = 1.0,
|
||||
temp_last: bool = False,
|
||||
adaptive_target: float = 1.0,
|
||||
adaptive_decay: float = 0.9,
|
||||
):
|
||||
# Steps with default parameters become no-ops
|
||||
stack = [
|
||||
@@ -113,7 +115,31 @@ class ComboSampler(CustomSampler):
|
||||
SS_TopK(top_k),
|
||||
SS_TopP(top_p),
|
||||
SS_Temperature(temperature if temp_last else 1.0),
|
||||
]
|
||||
|
||||
if adaptive_target != 1.0:
|
||||
stack += [
|
||||
SS_AdaptiveP(adaptive_target, adaptive_decay)
|
||||
]
|
||||
else:
|
||||
stack += [
|
||||
SS_Sample()
|
||||
]
|
||||
|
||||
super().__init__(stack)
|
||||
super().__init__(stack)
|
||||
|
||||
class AdaptivePSampler(CustomSampler):
    """
    Min-P followed by Adaptive-P
    """
    def __init__(
        self,
        min_p: float = 0.08,
        target: float = 0.5,
        decay: float = 0.9,
    ):
        # Truncate the distribution with Min-P first, then let Adaptive-P
        # select the token from what remains
        steps = [
            SS_MinP(min_p),
            SS_AdaptiveP(target, decay),
        ]
        super().__init__(steps)
|
||||
|
||||
@@ -58,6 +58,8 @@ def add_args(
|
||||
d.min_p = defs.get("min_p", 0.08)
|
||||
d.top_k = defs.get("top_k", 0)
|
||||
d.top_p = defs.get("top_p", 1.0)
|
||||
d.adaptive_target = defs.get("adaptive_target", 1.0)
|
||||
d.adaptive_decay = defs.get("adaptive_decay", 0.9)
|
||||
parser.add_argument("-temp", "--temperature", type = float, help = f"Sampling temperature (default: {d.temperature:.1f})", default = d.temperature)
|
||||
parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
|
||||
parser.add_argument("-repp", "--repetition_penalty", type = float, help = f"Repetition penalty, HF style, 1 to disable (default: {d.repetition_penalty:.1f})", default = d.repetition_penalty)
|
||||
@@ -67,6 +69,8 @@ def add_args(
|
||||
parser.add_argument("-minp", "--min_p", type = float, help = f"Min-P truncation, 0 to disable (default: {d.min_p:.2f})", default = d.min_p)
|
||||
parser.add_argument("-topk", "--top_k", type = int, help = f"Top-K truncation, 0 to disable (default: {d.top_k})", default = d.top_k)
|
||||
parser.add_argument("-topp", "--top_p", type = float, help = f"Top-P truncation, 1 to disable (default: {d.top_p:.2f})", default = d.top_p)
|
||||
parser.add_argument("-adaptive_target", "--adaptive_target", type = float, help = f"Adaptive-P target, 1 to disable (default: {d.adaptive_target:.2f})", default = d.adaptive_target)
|
||||
parser.add_argument("-adaptive_decay", "--adaptive_decay", type = float, help = f"Adaptive-P decay, if Adaptive-P enabled (default: {d.adaptive_decay:.2f})", default = d.adaptive_decay)
|
||||
|
||||
if cache:
|
||||
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
|
||||
@@ -94,6 +98,8 @@ def get_arg_sampler(args):
|
||||
top_k = args.top_k,
|
||||
top_p = args.top_p,
|
||||
temp_last = not args.temperature_first,
|
||||
adaptive_target = args.adaptive_target,
|
||||
adaptive_decay = args.adaptive_decay,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user