mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 02:31:48 +00:00
add legacy config converter
This commit is contained in:
@@ -4,7 +4,7 @@ import argparse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from common.config_models import TabbyConfigModel
|
from common.config_models import TabbyConfigModel
|
||||||
from common.utils import is_list_type
|
from common.utils import is_list_type, unwrap_optional
|
||||||
|
|
||||||
|
|
||||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||||
@@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# Loop through each top-level field in the config
|
# Loop through each top-level field in the config
|
||||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||||
field_type = field_info.annotation
|
field_type = unwrap_optional(field_info.annotation)
|
||||||
group = parser.add_argument_group(
|
group = parser.add_argument_group(
|
||||||
field_name, description=f"Arguments for {field_name}"
|
field_name, description=f"Arguments for {field_name}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -439,22 +439,32 @@ class DeveloperConfig(BaseConfigModel):
|
|||||||
class TabbyConfigModel(BaseModel):
|
class TabbyConfigModel(BaseModel):
|
||||||
"""Base model for a TabbyConfig."""
|
"""Base model for a TabbyConfig."""
|
||||||
|
|
||||||
config: ConfigOverrideConfig = Field(
|
config: Optional[ConfigOverrideConfig] = Field(
|
||||||
default_factory=ConfigOverrideConfig.model_construct
|
default_factory=ConfigOverrideConfig.model_construct
|
||||||
)
|
)
|
||||||
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
network: Optional[NetworkConfig] = Field(
|
||||||
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
default_factory=NetworkConfig.model_construct
|
||||||
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
)
|
||||||
draft_model: DraftModelConfig = Field(
|
logging: Optional[LoggingConfig] = Field(
|
||||||
|
default_factory=LoggingConfig.model_construct
|
||||||
|
)
|
||||||
|
model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
|
||||||
|
draft_model: Optional[DraftModelConfig] = Field(
|
||||||
default_factory=DraftModelConfig.model_construct
|
default_factory=DraftModelConfig.model_construct
|
||||||
)
|
)
|
||||||
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
|
lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
|
||||||
embeddings: EmbeddingsConfig = Field(
|
embeddings: Optional[EmbeddingsConfig] = Field(
|
||||||
default_factory=EmbeddingsConfig.model_construct
|
default_factory=EmbeddingsConfig.model_construct
|
||||||
)
|
)
|
||||||
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
|
sampling: Optional[SamplingConfig] = Field(
|
||||||
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
|
default_factory=SamplingConfig.model_construct
|
||||||
actions: UtilityActions = Field(default_factory=UtilityActions.model_construct)
|
)
|
||||||
|
developer: Optional[DeveloperConfig] = Field(
|
||||||
|
default_factory=DeveloperConfig.model_construct
|
||||||
|
)
|
||||||
|
actions: Optional[UtilityActions] = Field(
|
||||||
|
default_factory=UtilityActions.model_construct
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ import yaml
|
|||||||
import pathlib
|
import pathlib
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from os import getenv
|
from os import getenv, replace
|
||||||
|
|
||||||
from common.utils import unwrap, merge_dicts
|
from common.utils import unwrap, merge_dicts
|
||||||
from common.config_models import TabbyConfigModel
|
from common.config_models import TabbyConfigModel, generate_config_file
|
||||||
|
|
||||||
|
|
||||||
class TabbyConfig(TabbyConfigModel):
|
class TabbyConfig(TabbyConfigModel):
|
||||||
@@ -46,10 +46,25 @@ class TabbyConfig(TabbyConfigModel):
|
|||||||
def _from_file(self, config_path: pathlib.Path):
|
def _from_file(self, config_path: pathlib.Path):
|
||||||
"""loads config from a given file path"""
|
"""loads config from a given file path"""
|
||||||
|
|
||||||
|
legacy = False
|
||||||
|
cfg = {}
|
||||||
|
|
||||||
# try loading from file
|
# try loading from file
|
||||||
try:
|
try:
|
||||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||||
return unwrap(yaml.safe_load(config_file), {})
|
cfg = yaml.safe_load(config_file)
|
||||||
|
|
||||||
|
# FIXME: remove legacy config mapper
|
||||||
|
# load legacy config files
|
||||||
|
model = cfg.get("model", {})
|
||||||
|
|
||||||
|
if model.get("draft"):
|
||||||
|
legacy = True
|
||||||
|
cfg["draft"] = model["draft"]
|
||||||
|
if model.get("lora"):
|
||||||
|
legacy = True
|
||||||
|
cfg["lora"] = model["lora"]
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.info(f"The '{config_path.name}' file cannot be found")
|
logger.info(f"The '{config_path.name}' file cannot be found")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -58,8 +73,21 @@ class TabbyConfig(TabbyConfigModel):
|
|||||||
f"the following error:\n\n{exc}"
|
f"the following error:\n\n{exc}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# if no config file was loaded
|
if legacy:
|
||||||
return {}
|
logger.warning(
|
||||||
|
"legacy config.yml files are deprecated"
|
||||||
|
"Please upadte to the new version"
|
||||||
|
"Attempting auto migrationy"
|
||||||
|
)
|
||||||
|
new_cfg = TabbyConfigModel.model_validate(cfg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
replace(config_path, f"{config_path}.bak")
|
||||||
|
generate_config_file(model=new_cfg, filename=config_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auto migration failed: {e}")
|
||||||
|
|
||||||
|
return unwrap(cfg, {})
|
||||||
|
|
||||||
def _from_args(self, args: dict):
|
def _from_args(self, args: dict):
|
||||||
"""loads config from the provided arguments"""
|
"""loads config from the provided arguments"""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Common utility functions"""
|
"""Common utility functions"""
|
||||||
|
|
||||||
from typing import get_args, get_origin
|
from types import NoneType
|
||||||
|
from typing import Optional, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
|
|
||||||
def unwrap(wrapped, default=None):
|
def unwrap(wrapped, default=None):
|
||||||
@@ -47,7 +48,7 @@ def flat_map(input_list):
|
|||||||
return [item for sublist in input_list for item in sublist]
|
return [item for sublist in input_list for item in sublist]
|
||||||
|
|
||||||
|
|
||||||
def is_list_type(type_hint):
|
def is_list_type(type_hint) -> bool:
|
||||||
"""Checks if a type contains a list."""
|
"""Checks if a type contains a list."""
|
||||||
|
|
||||||
if get_origin(type_hint) is list:
|
if get_origin(type_hint) is list:
|
||||||
@@ -59,3 +60,16 @@ def is_list_type(type_hint):
|
|||||||
return any(is_list_type(arg) for arg in type_args)
|
return any(is_list_type(arg) for arg in type_args)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_optional(type_hint) -> Type:
|
||||||
|
"""unwrap Optional[type] annotations"""
|
||||||
|
|
||||||
|
if get_origin(type_hint) is Union:
|
||||||
|
args = get_args(type_hint)
|
||||||
|
if NoneType in args:
|
||||||
|
for arg in args:
|
||||||
|
if arg is not NoneType:
|
||||||
|
return arg
|
||||||
|
|
||||||
|
return type_hint
|
||||||
|
|||||||
Reference in New Issue
Block a user