From 948fcb7f5b8eaf829fcab6bb947615e6e78e7b4d Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 18 Sep 2024 01:06:34 +0100 Subject: [PATCH] migrate to ruamel.yaml --- common/tabby_config.py | 125 ++++++++++++++++++++--------------------- common/utils.py | 12 ++-- pyproject.toml | 2 +- 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/common/tabby_config.py b/common/tabby_config.py index 5491b9c..283dd17 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -1,14 +1,19 @@ -import yaml import pathlib from inspect import getdoc -from pydantic_core import PydanticUndefined -from loguru import logger -from textwrap import dedent -from typing import Optional from os import getenv +from textwrap import dedent +from typing import Any, Optional -from common.utils import unwrap, merge_dicts -from common.config_models import BaseConfigModel, TabbyConfigModel +from loguru import logger +from pydantic import BaseModel +from pydantic_core import PydanticUndefined +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap, CommentedSeq + +from common.config_models import TabbyConfigModel +from common.utils import merge_dicts, unwrap + +yaml = YAML() class TabbyConfig(TabbyConfigModel): @@ -57,7 +62,7 @@ class TabbyConfig(TabbyConfigModel): # try loading from file try: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: - cfg = yaml.safe_load(config_file) + cfg = yaml.load(config_file) # NOTE: Remove migration wrapper after a period of time # load legacy config files @@ -130,7 +135,7 @@ class TabbyConfig(TabbyConfigModel): """loads config from the provided arguments""" config = {} - config_override = unwrap(args.get("options", {}).get("config")) + config_override = args.get("options", {}).get("config", None) if config_override: logger.info("Config file override detected in args.") config = self._from_file(pathlib.Path(config_override)) @@ -166,15 +171,25 @@ class TabbyConfig(TabbyConfigModel): config: TabbyConfig = TabbyConfig() -# TODO: Possibly switch to ruamel.yaml for a more native implementation def generate_config_file( - model: BaseConfigModel = None, + model: BaseModel = None, filename: str = "config_sample.yml", indentation: int = 2, ) -> None: """Creates a config.yml file from Pydantic models.""" - # Add a cleaned up preamble + schema = unwrap(model, TabbyConfigModel()) + preamble = get_preamble() + + yaml_content = pydantic_model_to_yaml(schema) + + with open(filename, "w") as f: + f.write(preamble) + yaml.dump(yaml_content, f) + + +def get_preamble() -> str: + """Returns the cleaned up preamble for the config file.""" preamble = """ # Sample YAML file for configuration. # Comment and uncomment values as needed. @@ -184,61 +199,43 @@ def generate_config_file( # Unless specified in the comments, DO NOT put these options in quotes! # You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n """ + return dedent(preamble).lstrip() - # Trim and cleanup preamble - yaml = dedent(preamble).lstrip() - schema = unwrap(model, TabbyConfigModel()) +# Function to convert pydantic model to dict with field descriptions as comments +def pydantic_model_to_yaml(model: BaseModel) -> CommentedMap: + """ + Recursively converts a Pydantic model into a CommentedMap, + with descriptions as comments in YAML. + """ + # Create a CommentedMap to hold the output data + yaml_data = CommentedMap() - # TODO: Make the disordered iteration look cleaner - iter_once = False - for field, field_data in schema.model_fields.items(): - # Fetch from the existing model class if it's passed - # Probably can use this on schema too, but play it safe - if model and hasattr(model, field): - subfield_model = getattr(model, field) + # Loop through all fields in the model + for field_name, field_info in model.model_fields.items(): + value = getattr(model, field_name) + + # If the field is another Pydantic model + if isinstance(value, BaseModel): + yaml_data[field_name] = pydantic_model_to_yaml(value) + # If the field is a list of Pydantic models + elif ( + isinstance(value, list) + and len(value) > 0 + and isinstance(value[0], BaseModel) + ): + yaml_list = CommentedSeq() + for item in value: + yaml_list.append(pydantic_model_to_yaml(item)) + yaml_data[field_name] = yaml_list + # Otherwise, just assign the value else: - subfield_model = field_data.default_factory() + yaml_data[field_name] = value - if not subfield_model._metadata.include_in_config: - continue + # Add field description as a comment if available + if field_info.description: + yaml_data.yaml_set_comment_before_after_key( + field_name, before=field_info.description + ) - # Since the list is out of order with the length - # Add newlines from the beginning once one iteration finishes - # This is a sanity check for formatting - if iter_once: - yaml += "\n" - else: - iter_once = True - - for line in getdoc(subfield_model).splitlines(): - yaml += f"# {line}\n" - - yaml += f"{field}:\n" - - sub_iter_once = False - for subfield, subfield_data in subfield_model.model_fields.items(): - # Same logic as iter_once - if sub_iter_once: - yaml += "\n" - else: - sub_iter_once = True - - # If a value already exists, use it - if hasattr(subfield_model, subfield): - value = getattr(subfield_model, subfield) - elif subfield_data.default_factory: - value = subfield_data.default_factory() - else: - value = subfield_data.default - - value = value if value is not None else "" - value = value if value is not PydanticUndefined else "" - - for line in subfield_data.description.splitlines(): - yaml += f"{' ' * indentation}# {line}\n" - - yaml += f"{' ' * indentation}{subfield}: {value}\n" - - with open(filename, "w") as f: - f.write(yaml) + return yaml_data diff --git a/common/utils.py b/common/utils.py index f8b4671..dfa7e92 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,10 +1,12 @@ """Common utility functions""" from types import NoneType -from typing import Type, Union, get_args, get_origin +from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar + +T = TypeVar("T") -def unwrap(wrapped, default=None): +def unwrap(wrapped: Optional[T], default: T = None) -> T: """Unwrap function for Optionals.""" if wrapped is None: return default @@ -17,13 +19,13 @@ def coalesce(*args): return next((arg for arg in args if arg is not None), None) -def prune_dict(input_dict): +def prune_dict(input_dict: Dict) -> Dict: """Trim out instances of None from a dictionary.""" return {k: v for k, v in input_dict.items() if v is not None} -def merge_dict(dict1, dict2): +def merge_dict(dict1: Dict, dict2: Dict) -> Dict: """Merge 2 dictionaries""" for key, value in dict2.items(): if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict): @@ -33,7 +35,7 @@ def merge_dict(dict1, dict2): return dict1 -def merge_dicts(*dicts): +def merge_dicts(*dicts: Dict) -> Dict: """Merge an arbitrary amount of dictionaries""" result = {} for dictionary in dicts: diff --git a/pyproject.toml b/pyproject.toml index ad6f945..3289288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ requires-python = ">=3.10" dependencies = [ "fastapi-slim >= 0.110.0", "pydantic >= 2.0.0", - "PyYAML", + "ruamel.yaml", "rich", "uvicorn >= 0.28.1", "jinja2 >= 3.0.0",