mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 09:41:54 +00:00
fix model names
This commit is contained in:
@@ -4,13 +4,13 @@ from typing import List, Optional, Union
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
class config_config_model(BaseModel):
|
||||
class ConfigConfig(BaseModel):
|
||||
config: Optional[str] = Field(
|
||||
None, description=("Path to an overriding config.yml file")
|
||||
)
|
||||
|
||||
|
||||
class network_config_model(BaseModel):
|
||||
class NetworkConfig(BaseModel):
|
||||
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
||||
port: Optional[int] = Field(5000, description=("The port to host on"))
|
||||
disable_auth: Optional[bool] = Field(
|
||||
@@ -28,7 +28,7 @@ class network_config_model(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class logging_config_model(BaseModel):
|
||||
class LoggingConfig(BaseModel):
|
||||
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
||||
log_generation_params: Optional[bool] = Field(
|
||||
False, description=("Enable generation parameter logging")
|
||||
@@ -36,7 +36,7 @@ class logging_config_model(BaseModel):
|
||||
log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
|
||||
|
||||
|
||||
class model_config_model(BaseModel):
|
||||
class ModelConfig(BaseModel):
|
||||
model_dir: str = Field(
|
||||
"models",
|
||||
description=(
|
||||
@@ -171,8 +171,10 @@ class model_config_model(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
class draft_model_config_model(BaseModel):
|
||||
|
||||
class DraftModelConfig(BaseModel):
|
||||
draft_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
@@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class lora_instance_model(BaseModel):
|
||||
class LoraInstanceModel(BaseModel):
|
||||
name: str = Field(..., description=("Name of the LoRA model"))
|
||||
scaling: float = Field(
|
||||
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
|
||||
)
|
||||
|
||||
|
||||
class lora_config_model(BaseModel):
|
||||
class LoraConfig(BaseModel):
|
||||
lora_dir: Optional[str] = Field(
|
||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||
)
|
||||
loras: Optional[List[lora_instance_model]] = Field(
|
||||
loras: Optional[List[LoraInstanceModel]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"List of LoRAs to load and associated scaling factors (default scaling:"
|
||||
@@ -229,13 +231,13 @@ class lora_config_model(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class sampling_config_model(BaseModel):
|
||||
class SamplingConfig(BaseModel):
|
||||
override_preset: Optional[str] = Field(
|
||||
None, description=("Select a sampler override preset")
|
||||
)
|
||||
|
||||
|
||||
class developer_config_model(BaseModel):
|
||||
class DeveloperConfig(BaseModel):
|
||||
unsafe_launch: Optional[bool] = Field(
|
||||
False, description=("Skip Exllamav2 version check")
|
||||
)
|
||||
@@ -257,7 +259,7 @@ class developer_config_model(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class embeddings_config_model(BaseModel):
|
||||
class EmbeddingsConfig(BaseModel):
|
||||
embedding_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
@@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class tabby_config_model(BaseModel):
|
||||
config: config_config_model = Field(default_factory=config_config_model)
|
||||
network: network_config_model = Field(default_factory=network_config_model)
|
||||
logging: logging_config_model = Field(default_factory=logging_config_model)
|
||||
model: model_config_model = Field(default_factory=model_config_model)
|
||||
draft_model: draft_model_config_model = Field(
|
||||
default_factory=draft_model_config_model
|
||||
class TabbyConfigModel(BaseModel):
|
||||
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
|
||||
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
||||
draft_model: DraftModelConfig = Field(
|
||||
default_factory=DraftModelConfig.model_construct
|
||||
)
|
||||
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
|
||||
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
|
||||
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
|
||||
embeddings: EmbeddingsConfig = Field(
|
||||
default_factory=EmbeddingsConfig.model_construct
|
||||
)
|
||||
lora: lora_config_model = Field(default_factory=lora_config_model)
|
||||
sampling: sampling_config_model = Field(default_factory=sampling_config_model)
|
||||
developer: developer_config_model = Field(default_factory=developer_config_model)
|
||||
embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_defaults(cls, values):
|
||||
@@ -297,11 +301,11 @@ class tabby_config_model(BaseModel):
|
||||
values[field_name] = cls.__annotations__[field_name](**default_instance)
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True)
|
||||
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||
|
||||
|
||||
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||
schema = tabby_config_model.model_json_schema()
|
||||
schema = TabbyConfigModel.model_json_schema()
|
||||
|
||||
def dump_def(id: str, indent=2):
|
||||
yaml = ""
|
||||
|
||||
Reference in New Issue
Block a user