Files
tabbyAPI/backends/exllamav3/sampler.py
2026-01-20 22:57:36 +01:00

63 lines
1.6 KiB
Python

from dataclasses import dataclass, field
from typing import List
from exllamav3.generator.sampler import (
CustomSampler,
SS_Temperature,
SS_RepP,
SS_PresFreqP,
SS_Argmax,
SS_MinP,
SS_TopK,
SS_TopP,
SS_Sample,
SS_Base,
SS_AdaptiveP,
)
@dataclass
class ExllamaV3SamplerBuilder:
    """
    Custom sampler chain/stack for TabbyAPI

    Each helper method appends one or more sampling steps to ``stack`` in
    call order; ``build`` then wraps the accumulated steps in a
    ``CustomSampler``.
    """

    # Ordered sampling steps accumulated by the helper methods below.
    stack: List[SS_Base] = field(default_factory=list)

    def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay) -> None:
        """Append the repetition and presence/frequency penalty steps.

        Both steps receive the same ``penalty_range`` and ``rep_decay``.
        Note that ``SS_PresFreqP`` is given the presence penalty first and
        the frequency penalty second, the reverse of this method's own
        parameter order.
        """
        self.stack.extend(
            [
                SS_RepP(rep_p, penalty_range, rep_decay),
                SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
            ]
        )

    def temperature(self, temp) -> None:
        """Append an ``SS_Temperature`` step built from ``temp``."""
        self.stack.append(SS_Temperature(temp))

    def top_k(self, top_k) -> None:
        """Append an ``SS_TopK`` step built from ``top_k``."""
        self.stack.append(SS_TopK(top_k))

    def top_p(self, top_p) -> None:
        """Append an ``SS_TopP`` step built from ``top_p``."""
        self.stack.append(SS_TopP(top_p))

    def min_p(self, min_p) -> None:
        """Append an ``SS_MinP`` step built from ``min_p``."""
        self.stack.append(SS_MinP(min_p))

    def greedy(self) -> None:
        """Append an ``SS_Argmax`` step."""
        self.stack.append(SS_Argmax())

    def adaptive_p(self, adaptive_target, adaptive_decay) -> None:
        """Append an ``SS_AdaptiveP`` step.

        When this is the last step on the stack, ``build`` returns the chain
        without adding a final ``SS_Sample`` step.
        """
        self.stack.append(SS_AdaptiveP(adaptive_target, adaptive_decay))

    def build(self, greedy) -> CustomSampler:
        """Builds the final sampler from stack.

        The builder's ``stack`` is not modified: the sampler is constructed
        from a snapshot, so calling ``build`` repeatedly yields equivalent
        chains instead of accumulating extra ``SS_Sample`` steps.

        Args:
            greedy: When truthy (and the stack does not end in Adaptive-P),
                the accumulated stack is ignored and a sampler containing
                only ``SS_Argmax`` is returned.

        Returns:
            A ``CustomSampler`` wrapping the selected steps.
        """

        # Adaptive-P does categorical sampling already
        if self.stack and isinstance(self.stack[-1], SS_AdaptiveP):
            return CustomSampler(list(self.stack))

        # Use greedy if temp is 0
        if greedy:
            return CustomSampler([SS_Argmax()])

        # Terminate the chain with a sampling step on a copy, leaving the
        # builder's own stack untouched.
        return CustomSampler([*self.stack, SS_Sample()])