mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
63 lines
1.6 KiB
Python
63 lines
1.6 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import List
|
|
from exllamav3.generator.sampler import (
|
|
CustomSampler,
|
|
SS_Temperature,
|
|
SS_RepP,
|
|
SS_PresFreqP,
|
|
SS_Argmax,
|
|
SS_MinP,
|
|
SS_TopK,
|
|
SS_TopP,
|
|
SS_Sample,
|
|
SS_Base,
|
|
SS_AdaptiveP,
|
|
)
|
|
|
|
|
|
@dataclass
class ExllamaV3SamplerBuilder:
    """
    Custom sampler chain/stack for TabbyAPI.

    Each builder method appends one or more exllamav3 sampler steps
    (``SS_*`` objects) to ``stack`` in call order, so the order in which the
    methods are called is the order of the steps in the resulting chain.
    ``build()`` then wraps the steps in a ``CustomSampler``.
    """

    # Sampler steps accumulated so far, in the order the builder methods were
    # called. default_factory gives each builder its own list.
    stack: List[SS_Base] = field(default_factory=list)

    def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay) -> None:
        """Append the repetition and presence/frequency penalty steps.

        Args:
            rep_p: Repetition penalty, passed to ``SS_RepP``.
            freq_p: Frequency penalty, passed to ``SS_PresFreqP``.
            pres_p: Presence penalty, passed to ``SS_PresFreqP``.
            penalty_range: Penalty window, passed to both steps.
            rep_decay: Penalty decay, passed to both steps.
        """
        # Note the positional order: SS_PresFreqP takes presence before frequency.
        self.stack.extend(
            [
                SS_RepP(rep_p, penalty_range, rep_decay),
                SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
            ]
        )

    def temperature(self, temp) -> None:
        """Append a temperature step with the given temperature."""
        self.stack.append(SS_Temperature(temp))

    def top_k(self, top_k) -> None:
        """Append a top-k truncation step."""
        self.stack.append(SS_TopK(top_k))

    def top_p(self, top_p) -> None:
        """Append a top-p (nucleus) truncation step."""
        self.stack.append(SS_TopP(top_p))

    def min_p(self, min_p) -> None:
        """Append a min-p truncation step."""
        self.stack.append(SS_MinP(min_p))

    def greedy(self) -> None:
        """Append an argmax step to the stack.

        NOTE(review): ``build(greedy=False)`` will still append ``SS_Sample``
        after this step; confirm that Argmax followed by Sample is intended.
        """
        self.stack.append(SS_Argmax())

    def adaptive_p(self, adaptive_target, adaptive_decay) -> None:
        """Append an Adaptive-P step.

        When this is the last step on the stack, ``build()`` uses it as the
        terminal step instead of appending ``SS_Sample``.
        """
        self.stack.append(SS_AdaptiveP(adaptive_target, adaptive_decay))

    def build(self, greedy: bool) -> CustomSampler:
        """Build the final sampler from the accumulated stack.

        The builder's own ``stack`` is neither modified nor shared with the
        returned sampler. ``build`` can therefore be called more than once, and
        later builder calls do not alter a sampler that was already built.

        Args:
            greedy: When truthy (the caller's "temperature is 0" signal),
                return an argmax-only sampler. This is ignored if the stack
                already ends with an Adaptive-P step, which is checked first.

        Returns:
            A ``CustomSampler`` over a fresh list of sampler steps.
        """
        # Adaptive-P does categorical sampling already, so no SS_Sample is added.
        if self.stack and isinstance(self.stack[-1], SS_AdaptiveP):
            return CustomSampler(list(self.stack))

        # Use greedy if temp is 0.
        # NOTE(review): this discards every queued step, including the
        # repetition/presence/frequency penalties, which can change which token
        # has the highest logit. Confirm that dropping penalties under greedy
        # decoding is intended.
        if greedy:
            return CustomSampler([SS_Argmax()])

        # Append the terminal sampling step to a copy rather than to
        # self.stack. Otherwise repeated builds would stack extra SS_Sample
        # steps, and the sampler would alias the builder's list.
        return CustomSampler([*self.stack, SS_Sample()])
|