fix model names

This commit is contained in:
TerminalMan
2024-09-12 17:00:07 +01:00
parent 05f1c3e293
commit 8b48f00271
2 changed files with 38 additions and 35 deletions

View File

@@ -4,13 +4,13 @@ from typing import List, Optional, Union
from common.utils import unwrap
class config_config_model(BaseModel):
class ConfigConfig(BaseModel):
config: Optional[str] = Field(
None, description=("Path to an overriding config.yml file")
)
class network_config_model(BaseModel):
class NetworkConfig(BaseModel):
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
port: Optional[int] = Field(5000, description=("The port to host on"))
disable_auth: Optional[bool] = Field(
@@ -28,7 +28,7 @@ class network_config_model(BaseModel):
)
class logging_config_model(BaseModel):
class LoggingConfig(BaseModel):
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
log_generation_params: Optional[bool] = Field(
False, description=("Enable generation parameter logging")
@@ -36,7 +36,7 @@ class logging_config_model(BaseModel):
log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
class model_config_model(BaseModel):
class ModelConfig(BaseModel):
model_dir: str = Field(
"models",
description=(
@@ -171,8 +171,10 @@ class model_config_model(BaseModel):
),
)
model_config = ConfigDict(protected_namespaces=())
class draft_model_config_model(BaseModel):
class DraftModelConfig(BaseModel):
draft_model_dir: Optional[str] = Field(
"models",
description=(
@@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel):
)
class lora_instance_model(BaseModel):
class LoraInstanceModel(BaseModel):
name: str = Field(..., description=("Name of the LoRA model"))
scaling: float = Field(
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
)
class lora_config_model(BaseModel):
class LoraConfig(BaseModel):
lora_dir: Optional[str] = Field(
"loras", description=("Directory to look for LoRAs (default: 'loras')")
)
loras: Optional[List[lora_instance_model]] = Field(
loras: Optional[List[LoraInstanceModel]] = Field(
None,
description=(
"List of LoRAs to load and associated scaling factors (default scaling:"
@@ -229,13 +231,13 @@ class lora_config_model(BaseModel):
)
class sampling_config_model(BaseModel):
class SamplingConfig(BaseModel):
override_preset: Optional[str] = Field(
None, description=("Select a sampler override preset")
)
class developer_config_model(BaseModel):
class DeveloperConfig(BaseModel):
unsafe_launch: Optional[bool] = Field(
False, description=("Skip Exllamav2 version check")
)
@@ -257,7 +259,7 @@ class developer_config_model(BaseModel):
)
class embeddings_config_model(BaseModel):
class EmbeddingsConfig(BaseModel):
embedding_model_dir: Optional[str] = Field(
"models",
description=(
@@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel):
)
class tabby_config_model(BaseModel):
config: config_config_model = Field(default_factory=config_config_model)
network: network_config_model = Field(default_factory=network_config_model)
logging: logging_config_model = Field(default_factory=logging_config_model)
model: model_config_model = Field(default_factory=model_config_model)
draft_model: draft_model_config_model = Field(
default_factory=draft_model_config_model
class TabbyConfigModel(BaseModel):
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
draft_model: DraftModelConfig = Field(
default_factory=DraftModelConfig.model_construct
)
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
embeddings: EmbeddingsConfig = Field(
default_factory=EmbeddingsConfig.model_construct
)
lora: lora_config_model = Field(default_factory=lora_config_model)
sampling: sampling_config_model = Field(default_factory=sampling_config_model)
developer: developer_config_model = Field(default_factory=developer_config_model)
embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
@model_validator(mode="before")
def set_defaults(cls, values):
@@ -297,11 +301,11 @@ class tabby_config_model(BaseModel):
values[field_name] = cls.__annotations__[field_name](**default_instance)
return values
model_config = ConfigDict(validate_assignment=True)
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
def generate_config_file(filename="config_sample.yml", indentation=2):
schema = tabby_config_model.model_json_schema()
schema = TabbyConfigModel.model_json_schema()
def dump_def(id: str, indent=2):
yaml = ""