mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Tree: Update to cleanup globals
Use the module singleton pattern to share global state. This can also be a modified version of the Global Object Pattern. The main reason this pattern is used is for ease of use when handling global state rather than adding extra dependencies for a DI parameter. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -274,32 +274,28 @@ class BaseSamplerRequest(BaseModel):
|
||||
|
||||
|
||||
# Global for default overrides
|
||||
DEFAULT_OVERRIDES = {}
|
||||
overrides = {}
|
||||
|
||||
|
||||
def get_sampler_overrides():
|
||||
return DEFAULT_OVERRIDES
|
||||
|
||||
|
||||
def set_overrides_from_dict(new_overrides: dict):
|
||||
def overrides_from_dict(new_overrides: dict):
|
||||
"""Wrapper function to update sampler overrides"""
|
||||
|
||||
global DEFAULT_OVERRIDES
|
||||
global overrides
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
DEFAULT_OVERRIDES = prune_dict(new_overrides)
|
||||
overrides = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
||||
def set_overrides_from_file(preset_name: str):
|
||||
def overrides_from_file(preset_name: str):
|
||||
"""Fetches an override preset from a file"""
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
set_overrides_from_dict(preset)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
else:
|
||||
@@ -316,13 +312,13 @@ def set_overrides_from_file(preset_name: str):
|
||||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
|
||||
return unwrap(overrides.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in DEFAULT_OVERRIDES.items():
|
||||
for var, value in overrides.items():
|
||||
override = value.get("override")
|
||||
force = unwrap(value.get("force"), False)
|
||||
if force and override:
|
||||
|
||||
Reference in New Issue
Block a user