mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-30 03:01:44 +00:00
patch pydantic config into old config
- convert pydantic to dict to avoid errors with current files - fix formatting
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import get_origin, get_args, Optional, Union, List
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from common.tabby_config import config
|
from common.tabby_config import config
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool(value):
|
def str_to_bool(value):
|
||||||
"""Converts a string into a boolean value"""
|
"""Converts a string into a boolean value"""
|
||||||
|
|
||||||
@@ -40,34 +41,43 @@ def init_argparser():
|
|||||||
for field_name, field_type in config.__annotations__.items():
|
for field_name, field_type in config.__annotations__.items():
|
||||||
# Get the sub-model type (e.g., ModelA, ModelB)
|
# Get the sub-model type (e.g., ModelA, ModelB)
|
||||||
sub_model = field_type.__base__
|
sub_model = field_type.__base__
|
||||||
|
|
||||||
# Create argument group for the sub-model
|
# Create argument group for the sub-model
|
||||||
group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}")
|
group = parser.add_argument_group(
|
||||||
|
field_name, description=f"Arguments for {field_name}"
|
||||||
|
)
|
||||||
|
|
||||||
# Loop through each field in the sub-model (e.g., ModelA, ModelB)
|
# Loop through each field in the sub-model (e.g., ModelA, ModelB)
|
||||||
for sub_field_name, sub_field_type in field_type.__annotations__.items():
|
for sub_field_name, sub_field_type in field_type.__annotations__.items():
|
||||||
field = field_type.__fields__[sub_field_name]
|
field = field_type.__fields__[sub_field_name]
|
||||||
help_text = field.description if field.description else "No description available"
|
help_text = (
|
||||||
|
field.description if field.description else "No description available"
|
||||||
|
)
|
||||||
|
|
||||||
# Handle Optional types or other generic types
|
# Handle Optional types or other generic types
|
||||||
origin = get_origin(sub_field_type)
|
origin = get_origin(sub_field_type)
|
||||||
if origin is Union: # Check if the type is Union (which includes Optional)
|
if origin is Union: # Check if the type is Union (which includes Optional)
|
||||||
sub_field_type = next(t for t in get_args(sub_field_type) if t is not type(None))
|
sub_field_type = next(
|
||||||
elif origin is List : sub_field_type = get_args(sub_field_type)[0]
|
t for t in get_args(sub_field_type) if t is not type(None)
|
||||||
|
)
|
||||||
|
elif origin is List:
|
||||||
|
sub_field_type = get_args(sub_field_type)[0]
|
||||||
|
|
||||||
# Map Pydantic types to argparse types
|
# Map Pydantic types to argparse types
|
||||||
print(sub_field_type, type(sub_field_type))
|
print(sub_field_type, type(sub_field_type))
|
||||||
if isinstance(sub_field_type, type) and issubclass(sub_field_type, (int, float, str, bool)):
|
if isinstance(sub_field_type, type) and issubclass(
|
||||||
|
sub_field_type, (int, float, str, bool)
|
||||||
|
):
|
||||||
arg_type = sub_field_type
|
arg_type = sub_field_type
|
||||||
else:
|
else:
|
||||||
arg_type = str # Default to string for unknown types
|
arg_type = str # Default to string for unknown types
|
||||||
|
|
||||||
# Add the argument for each field in the sub-model
|
# Add the argument for each field in the sub-model
|
||||||
group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
|
group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||||
|
|
||||||
@@ -81,4 +91,4 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
|
|||||||
|
|
||||||
arg_groups[group.title] = group_dict
|
arg_groups[group.title] = group_dict
|
||||||
|
|
||||||
return arg_groups
|
return arg_groups
|
||||||
|
|||||||
9
main.py
9
main.py
@@ -69,12 +69,14 @@ async def entrypoint_async():
|
|||||||
model_path = pathlib.Path(config.model.model_dir)
|
model_path = pathlib.Path(config.model.model_dir)
|
||||||
model_path = model_path / model_name
|
model_path = model_path / model_name
|
||||||
|
|
||||||
await model.load_model(model_path.resolve(), **config.model)
|
# TODO: remove model_dump()
|
||||||
|
await model.load_model(model_path.resolve(), **config.model.model_dump())
|
||||||
|
|
||||||
# Load loras after loading the model
|
# Load loras after loading the model
|
||||||
if config.lora.loras:
|
if config.lora.loras:
|
||||||
lora_dir = pathlib.Path(config.lora.lora_dir)
|
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||||
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
# TODO: remove model_dump()
|
||||||
|
await model.container.load_loras(lora_dir.resolve(), **config.lora.model_dump())
|
||||||
|
|
||||||
# If an initial embedding model name is specified, create a separate container
|
# If an initial embedding model name is specified, create a separate container
|
||||||
# and load the model
|
# and load the model
|
||||||
@@ -84,7 +86,8 @@ async def entrypoint_async():
|
|||||||
embedding_model_path = embedding_model_path / embedding_model_name
|
embedding_model_path = embedding_model_path / embedding_model_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await model.load_embedding_model(embedding_model_path, **config.embeddings)
|
# TODO: remove model_dump()
|
||||||
|
await model.load_embedding_model(embedding_model_path, **config.embeddings.model_dump())
|
||||||
except ImportError as ex:
|
except ImportError as ex:
|
||||||
logger.error(ex.msg)
|
logger.error(ex.msg)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user