mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 18:21:42 +00:00
Config: Allow existing values to get included in generated file
Allows for generation from an existing config file. Primarily used for migration purposes. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -4,7 +4,7 @@ import argparse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from common.config_models import TabbyConfigModel
|
from common.config_models import TabbyConfigModel
|
||||||
from common.utils import is_list_type, unwrap_optional
|
from common.utils import is_list_type, unwrap_optional_type
|
||||||
|
|
||||||
|
|
||||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||||
@@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# Loop through each top-level field in the config
|
# Loop through each top-level field in the config
|
||||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||||
field_type = unwrap_optional(field_info.annotation)
|
field_type = unwrap_optional_type(field_info.annotation)
|
||||||
group = parser.add_argument_group(
|
group = parser.add_argument_group(
|
||||||
field_name, description=f"Arguments for {field_name}"
|
field_name, description=f"Arguments for {field_name}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from inspect import getdoc
|
from inspect import getdoc
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic_core import PydanticUndefined
|
from common.utils import unwrap
|
||||||
|
|
||||||
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||||
|
|
||||||
@@ -488,12 +489,17 @@ def generate_config_file(
|
|||||||
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
|
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
|
||||||
""")
|
""")
|
||||||
|
|
||||||
schema = model if model else TabbyConfigModel()
|
schema = unwrap(model, TabbyConfigModel())
|
||||||
|
|
||||||
# TODO: Make the disordered iteration look cleaner
|
# TODO: Make the disordered iteration look cleaner
|
||||||
iter_once = False
|
iter_once = False
|
||||||
for field, field_data in schema.model_fields.items():
|
for field, field_data in schema.model_fields.items():
|
||||||
subfield_model = field_data.default_factory()
|
# Fetch from the existing model class if it's passed
|
||||||
|
# Probably can use this on schema too, but play it safe
|
||||||
|
if model:
|
||||||
|
subfield_model = getattr(model, field, None)
|
||||||
|
else:
|
||||||
|
subfield_model = field_data.default_factory()
|
||||||
|
|
||||||
if not subfield_model._metadata.include_in_config:
|
if not subfield_model._metadata.include_in_config:
|
||||||
continue
|
continue
|
||||||
@@ -519,7 +525,10 @@ def generate_config_file(
|
|||||||
else:
|
else:
|
||||||
sub_iter_once = True
|
sub_iter_once = True
|
||||||
|
|
||||||
if subfield_data.default_factory:
|
# If a value already exists, use it
|
||||||
|
if hasattr(subfield_model, subfield):
|
||||||
|
value = getattr(subfield_model, subfield)
|
||||||
|
elif subfield_data.default_factory:
|
||||||
value = subfield_data.default_factory()
|
value = subfield_data.default_factory()
|
||||||
else:
|
else:
|
||||||
value = subfield_data.default
|
value = subfield_data.default
|
||||||
|
|||||||
@@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def unwrap_optional(type_hint) -> Type:
|
def unwrap_optional_type(type_hint) -> Type:
|
||||||
"""unwrap Optional[type] annotations"""
|
"""
|
||||||
|
Unwrap Optional[type] annotations.
|
||||||
|
This is not the same as unwrap.
|
||||||
|
"""
|
||||||
|
|
||||||
if get_origin(type_hint) is Union:
|
if get_origin(type_hint) is Union:
|
||||||
args = get_args(type_hint)
|
args = get_args(type_hint)
|
||||||
|
|||||||
Reference in New Issue
Block a user