Config: Allow existing values to get included in generated file

Allows for generation from an existing config file. Primarily used
for migration purposes.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-09-16 12:19:58 -04:00
parent 7f03003437
commit 81ae461eb8
3 changed files with 20 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ import argparse
from pydantic import BaseModel from pydantic import BaseModel
from common.config_models import TabbyConfigModel from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional from common.utils import is_list_type, unwrap_optional_type
def add_field_to_group(group, field_name, field_type, field) -> None: def add_field_to_group(group, field_name, field_type, field) -> None:
@@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
# Loop through each top-level field in the config # Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items(): for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional(field_info.annotation) field_type = unwrap_optional_type(field_info.annotation)
group = parser.add_argument_group( group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}" field_name, description=f"Arguments for {field_name}"
) )

View File

@@ -1,10 +1,11 @@
from inspect import getdoc from inspect import getdoc
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic_core import PydanticUndefined
from textwrap import dedent from textwrap import dedent
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from pydantic_core import PydanticUndefined from common.utils import unwrap
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
@@ -488,12 +489,17 @@ def generate_config_file(
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n # You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
""") """)
schema = model if model else TabbyConfigModel() schema = unwrap(model, TabbyConfigModel())
# TODO: Make the disordered iteration look cleaner # TODO: Make the disordered iteration look cleaner
iter_once = False iter_once = False
for field, field_data in schema.model_fields.items(): for field, field_data in schema.model_fields.items():
subfield_model = field_data.default_factory() # Fetch from the existing model class if it's passed
# Probably can use this on schema too, but play it safe
if model:
subfield_model = getattr(model, field, None)
else:
subfield_model = field_data.default_factory()
if not subfield_model._metadata.include_in_config: if not subfield_model._metadata.include_in_config:
continue continue
@@ -519,7 +525,10 @@ def generate_config_file(
else: else:
sub_iter_once = True sub_iter_once = True
if subfield_data.default_factory: # If a value already exists, use it
if hasattr(subfield_model, subfield):
value = getattr(subfield_model, subfield)
elif subfield_data.default_factory:
value = subfield_data.default_factory() value = subfield_data.default_factory()
else: else:
value = subfield_data.default value = subfield_data.default

View File

@@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool:
return False return False
def unwrap_optional(type_hint) -> Type: def unwrap_optional_type(type_hint) -> Type:
"""unwrap Optional[type] annotations""" """
Unwrap Optional[type] annotations.
This is not the same as unwrap.
"""
if get_origin(type_hint) is Union: if get_origin(type_hint) is Union:
args = get_args(type_hint) args = get_args(type_hint)