migrate to ruamel.yaml

This commit is contained in:
TerminalMan
2024-09-18 01:06:34 +01:00
parent bb4dd7200e
commit 948fcb7f5b
3 changed files with 69 additions and 70 deletions

View File

@@ -1,14 +1,19 @@
import yaml
import pathlib import pathlib
from inspect import getdoc from inspect import getdoc
from pydantic_core import PydanticUndefined
from loguru import logger
from textwrap import dedent
from typing import Optional
from os import getenv from os import getenv
from textwrap import dedent
from typing import Any, Optional
from common.utils import unwrap, merge_dicts from loguru import logger
from common.config_models import BaseConfigModel, TabbyConfigModel from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap, CommentedSeq
from common.config_models import TabbyConfigModel
from common.utils import merge_dicts, unwrap
yaml = YAML()
class TabbyConfig(TabbyConfigModel): class TabbyConfig(TabbyConfigModel):
@@ -57,7 +62,7 @@ class TabbyConfig(TabbyConfigModel):
# try loading from file # try loading from file
try: try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
cfg = yaml.safe_load(config_file) cfg = yaml.load(config_file)
# NOTE: Remove migration wrapper after a period of time # NOTE: Remove migration wrapper after a period of time
# load legacy config files # load legacy config files
@@ -130,7 +135,7 @@ class TabbyConfig(TabbyConfigModel):
"""loads config from the provided arguments""" """loads config from the provided arguments"""
config = {} config = {}
config_override = unwrap(args.get("options", {}).get("config")) config_override = args.get("options", {}).get("config", None)
if config_override: if config_override:
logger.info("Config file override detected in args.") logger.info("Config file override detected in args.")
config = self._from_file(pathlib.Path(config_override)) config = self._from_file(pathlib.Path(config_override))
@@ -166,15 +171,25 @@ class TabbyConfig(TabbyConfigModel):
config: TabbyConfig = TabbyConfig() config: TabbyConfig = TabbyConfig()
# TODO: Possibly switch to ruamel.yaml for a more native implementation
def generate_config_file( def generate_config_file(
model: BaseConfigModel = None, model: BaseModel = None,
filename: str = "config_sample.yml", filename: str = "config_sample.yml",
indentation: int = 2, indentation: int = 2,
) -> None: ) -> None:
"""Creates a config.yml file from Pydantic models.""" """Creates a config.yml file from Pydantic models."""
# Add a cleaned up preamble schema = unwrap(model, TabbyConfigModel())
preamble = get_preamble()
yaml_content = pydantic_model_to_yaml(schema)
with open(filename, "w") as f:
f.write(preamble)
yaml.dump(yaml_content, f)
def get_preamble() -> str:
"""Returns the cleaned up preamble for the config file."""
preamble = """ preamble = """
# Sample YAML file for configuration. # Sample YAML file for configuration.
# Comment and uncomment values as needed. # Comment and uncomment values as needed.
@@ -184,61 +199,43 @@ def generate_config_file(
# Unless specified in the comments, DO NOT put these options in quotes! # Unless specified in the comments, DO NOT put these options in quotes!
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n # You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
""" """
return dedent(preamble).lstrip()
# Trim and cleanup preamble
yaml = dedent(preamble).lstrip()
schema = unwrap(model, TabbyConfigModel()) # Function to convert pydantic model to dict with field descriptions as comments
def pydantic_model_to_yaml(model: BaseModel) -> CommentedMap:
"""
Recursively converts a Pydantic model into a CommentedMap,
with descriptions as comments in YAML.
"""
# Create a CommentedMap to hold the output data
yaml_data = CommentedMap()
# TODO: Make the disordered iteration look cleaner # Loop through all fields in the model
iter_once = False for field_name, field_info in model.model_fields.items():
for field, field_data in schema.model_fields.items(): value = getattr(model, field_name)
# Fetch from the existing model class if it's passed
# Probably can use this on schema too, but play it safe # If the field is another Pydantic model
if model and hasattr(model, field): if isinstance(value, BaseModel):
subfield_model = getattr(model, field) yaml_data[field_name] = pydantic_model_to_yaml(value)
# If the field is a list of Pydantic models
elif (
isinstance(value, list)
and len(value) > 0
and isinstance(value[0], BaseModel)
):
yaml_list = CommentedSeq()
for item in value:
yaml_list.append(pydantic_model_to_yaml(item))
yaml_data[field_name] = yaml_list
# Otherwise, just assign the value
else: else:
subfield_model = field_data.default_factory() yaml_data[field_name] = value
if not subfield_model._metadata.include_in_config: # Add field description as a comment if available
continue if field_info.description:
yaml_data.yaml_set_comment_before_after_key(
field_name, before=field_info.description
)
# Since the list is out of order with the length return yaml_data
# Add newlines from the beginning once one iteration finishes
# This is a sanity check for formatting
if iter_once:
yaml += "\n"
else:
iter_once = True
for line in getdoc(subfield_model).splitlines():
yaml += f"# {line}\n"
yaml += f"{field}:\n"
sub_iter_once = False
for subfield, subfield_data in subfield_model.model_fields.items():
# Same logic as iter_once
if sub_iter_once:
yaml += "\n"
else:
sub_iter_once = True
# If a value already exists, use it
if hasattr(subfield_model, subfield):
value = getattr(subfield_model, subfield)
elif subfield_data.default_factory:
value = subfield_data.default_factory()
else:
value = subfield_data.default
value = value if value is not None else ""
value = value if value is not PydanticUndefined else ""
for line in subfield_data.description.splitlines():
yaml += f"{' ' * indentation}# {line}\n"
yaml += f"{' ' * indentation}{subfield}: {value}\n"
with open(filename, "w") as f:
f.write(yaml)

View File

@@ -1,10 +1,12 @@
"""Common utility functions""" """Common utility functions"""
from types import NoneType from types import NoneType
from typing import Type, Union, get_args, get_origin from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
T = TypeVar("T")
def unwrap(wrapped, default=None): def unwrap(wrapped: Optional[T], default: T = None) -> T:
"""Unwrap function for Optionals.""" """Unwrap function for Optionals."""
if wrapped is None: if wrapped is None:
return default return default
@@ -17,13 +19,13 @@ def coalesce(*args):
return next((arg for arg in args if arg is not None), None) return next((arg for arg in args if arg is not None), None)
def prune_dict(input_dict): def prune_dict(input_dict: Dict) -> Dict:
"""Trim out instances of None from a dictionary.""" """Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None} return {k: v for k, v in input_dict.items() if v is not None}
def merge_dict(dict1, dict2): def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
"""Merge 2 dictionaries""" """Merge 2 dictionaries"""
for key, value in dict2.items(): for key, value in dict2.items():
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict): if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
@@ -33,7 +35,7 @@ def merge_dict(dict1, dict2):
return dict1 return dict1
def merge_dicts(*dicts): def merge_dicts(*dicts: Dict) -> Dict:
"""Merge an arbitrary amount of dictionaries""" """Merge an arbitrary amount of dictionaries"""
result = {} result = {}
for dictionary in dicts: for dictionary in dicts:

View File

@@ -18,7 +18,7 @@ requires-python = ">=3.10"
dependencies = [ dependencies = [
"fastapi-slim >= 0.110.0", "fastapi-slim >= 0.110.0",
"pydantic >= 2.0.0", "pydantic >= 2.0.0",
"PyYAML", "ruamel.yaml",
"rich", "rich",
"uvicorn >= 0.28.1", "uvicorn >= 0.28.1",
"jinja2 >= 3.0.0", "jinja2 >= 3.0.0",