mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Sampling: Copy over iterable overrides
If an override was iterable, any modifications to the returned value would alter the reference to the global storage dict. Therefore, copy the structure if it's an iterable so any modification won't alter the original override. Also apply this for the function that checks for forced overrides. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import pathlib
|
||||
import yaml
|
||||
from copy import deepcopy
|
||||
from loguru import logger
|
||||
from pydantic import AliasChoices, BaseModel, Field
|
||||
from typing import Dict, List, Optional, Union
|
||||
@@ -376,14 +377,19 @@ def get_all_presets():
|
||||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
|
||||
default_value = unwrap(
|
||||
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
|
||||
fallback,
|
||||
)
|
||||
|
||||
return default_value
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in overrides_container.overrides.items():
|
||||
override = value.get("override")
|
||||
override = deepcopy(value.get("override"))
|
||||
original_value = getattr(params, var, None)
|
||||
|
||||
# Force takes precedence over additive
|
||||
|
||||
@@ -15,6 +15,6 @@ def coalesce(*args):
|
||||
|
||||
|
||||
def prune_dict(input_dict):
|
||||
"""Trim out instances of None from a dictionary"""
|
||||
"""Trim out instances of None from a dictionary."""
|
||||
|
||||
return {k: v for k, v in input_dict.items() if v is not None}
|
||||
|
||||
Reference in New Issue
Block a user