add legacy config converter

This commit is contained in:
TerminalMan
2024-09-16 14:12:47 +01:00
parent b6dd21f737
commit 564bdcf0a8
4 changed files with 71 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ import argparse
from pydantic import BaseModel from pydantic import BaseModel
from common.config_models import TabbyConfigModel from common.config_models import TabbyConfigModel
from common.utils import is_list_type from common.utils import is_list_type, unwrap_optional
def add_field_to_group(group, field_name, field_type, field) -> None: def add_field_to_group(group, field_name, field_type, field) -> None:
@@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
# Loop through each top-level field in the config # Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items(): for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = field_info.annotation field_type = unwrap_optional(field_info.annotation)
group = parser.add_argument_group( group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}" field_name, description=f"Arguments for {field_name}"
) )

View File

@@ -439,22 +439,32 @@ class DeveloperConfig(BaseConfigModel):
class TabbyConfigModel(BaseModel): class TabbyConfigModel(BaseModel):
"""Base model for a TabbyConfig.""" """Base model for a TabbyConfig."""
config: ConfigOverrideConfig = Field( config: Optional[ConfigOverrideConfig] = Field(
default_factory=ConfigOverrideConfig.model_construct default_factory=ConfigOverrideConfig.model_construct
) )
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) network: Optional[NetworkConfig] = Field(
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) default_factory=NetworkConfig.model_construct
model: ModelConfig = Field(default_factory=ModelConfig.model_construct) )
draft_model: DraftModelConfig = Field( logging: Optional[LoggingConfig] = Field(
default_factory=LoggingConfig.model_construct
)
model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
draft_model: Optional[DraftModelConfig] = Field(
default_factory=DraftModelConfig.model_construct default_factory=DraftModelConfig.model_construct
) )
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct) lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
embeddings: EmbeddingsConfig = Field( embeddings: Optional[EmbeddingsConfig] = Field(
default_factory=EmbeddingsConfig.model_construct default_factory=EmbeddingsConfig.model_construct
) )
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct) sampling: Optional[SamplingConfig] = Field(
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct) default_factory=SamplingConfig.model_construct
actions: UtilityActions = Field(default_factory=UtilityActions.model_construct) )
developer: Optional[DeveloperConfig] = Field(
default_factory=DeveloperConfig.model_construct
)
actions: Optional[UtilityActions] = Field(
default_factory=UtilityActions.model_construct
)
model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) model_config = ConfigDict(validate_assignment=True, protected_namespaces=())

View File

@@ -2,10 +2,10 @@ import yaml
import pathlib import pathlib
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from os import getenv from os import getenv, replace
from common.utils import unwrap, merge_dicts from common.utils import unwrap, merge_dicts
from common.config_models import TabbyConfigModel from common.config_models import TabbyConfigModel, generate_config_file
class TabbyConfig(TabbyConfigModel): class TabbyConfig(TabbyConfigModel):
@@ -46,10 +46,25 @@ class TabbyConfig(TabbyConfigModel):
def _from_file(self, config_path: pathlib.Path): def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path""" """loads config from a given file path"""
legacy = False
cfg = {}
# try loading from file # try loading from file
try: try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
return unwrap(yaml.safe_load(config_file), {}) cfg = yaml.safe_load(config_file)
# FIXME: remove legacy config mapper
# load legacy config files
model = cfg.get("model", {})
if model.get("draft"):
legacy = True
cfg["draft"] = model["draft"]
if model.get("lora"):
legacy = True
cfg["lora"] = model["lora"]
except FileNotFoundError: except FileNotFoundError:
logger.info(f"The '{config_path.name}' file cannot be found") logger.info(f"The '{config_path.name}' file cannot be found")
except Exception as exc: except Exception as exc:
@@ -58,8 +73,21 @@ class TabbyConfig(TabbyConfigModel):
f"the following error:\n\n{exc}" f"the following error:\n\n{exc}"
) )
# if no config file was loaded if legacy:
return {} logger.warning(
"legacy config.yml files are deprecated"
"Please upadte to the new version"
"Attempting auto migrationy"
)
new_cfg = TabbyConfigModel.model_validate(cfg)
try:
replace(config_path, f"{config_path}.bak")
generate_config_file(model=new_cfg, filename=config_path)
except Exception as e:
logger.error(f"Auto migration failed: {e}")
return unwrap(cfg, {})
def _from_args(self, args: dict): def _from_args(self, args: dict):
"""loads config from the provided arguments""" """loads config from the provided arguments"""

View File

@@ -1,6 +1,7 @@
"""Common utility functions""" """Common utility functions"""
from typing import get_args, get_origin from types import NoneType
from typing import Optional, Type, Union, get_args, get_origin
def unwrap(wrapped, default=None): def unwrap(wrapped, default=None):
@@ -47,7 +48,7 @@ def flat_map(input_list):
return [item for sublist in input_list for item in sublist] return [item for sublist in input_list for item in sublist]
def is_list_type(type_hint): def is_list_type(type_hint) -> bool:
"""Checks if a type contains a list.""" """Checks if a type contains a list."""
if get_origin(type_hint) is list: if get_origin(type_hint) is list:
@@ -59,3 +60,16 @@ def is_list_type(type_hint):
return any(is_list_type(arg) for arg in type_args) return any(is_list_type(arg) for arg in type_args)
return False return False
def unwrap_optional(type_hint) -> Type:
"""unwrap Optional[type] annotations"""
if get_origin(type_hint) is Union:
args = get_args(type_hint)
if NoneType in args:
for arg in args:
if arg is not NoneType:
return arg
return type_hint