mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Model: Make model params return a model card
The model card is a unified structure for sharing model params. Rather than kwargs, use this instead. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from endpoints.core.types.model import ModelCard
|
||||
|
||||
|
||||
class BaseModelContainer(abc.ABC):
|
||||
@@ -189,7 +190,7 @@ class BaseModelContainer(abc.ABC):
|
||||
|
||||
# TODO: Replace by yielding a model card
|
||||
@abc.abstractmethod
|
||||
def get_model_parameters(self) -> Dict[str, Any]:
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
Returns a dictionary of the current model's configuration parameters.
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import calculate_rope_alpha, coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
||||
class ExllamaV2Container(BaseModelContainer):
|
||||
@@ -379,35 +380,43 @@ class ExllamaV2Container(BaseModelContainer):
|
||||
# Return the created instance
|
||||
return self
|
||||
|
||||
def get_model_parameters(self):
|
||||
model_params = {
|
||||
"name": self.model_dir.name,
|
||||
"rope_scale": self.config.scale_pos_emb,
|
||||
"rope_alpha": self.config.scale_alpha_value,
|
||||
"max_seq_len": self.config.max_seq_len,
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"cache_size": self.cache_size,
|
||||
"cache_mode": self.cache_mode,
|
||||
"chunk_size": self.config.max_input_len,
|
||||
"use_vision": self.use_vision,
|
||||
}
|
||||
def model_info(self):
|
||||
draft_model_card: ModelCard = None
|
||||
if self.draft_config:
|
||||
draft_model_params = ModelCardParameters(
|
||||
max_seq_len=self.draft_config.max_seq_len,
|
||||
rope_scale=self.draft_config.scale_pos_emb,
|
||||
rope_alpha=self.draft_config.scale_alpha_value,
|
||||
cache_mode=self.draft_cache_mode,
|
||||
)
|
||||
|
||||
draft_model_card = ModelCard(
|
||||
id=self.draft_model_dir.name,
|
||||
parameters=draft_model_params,
|
||||
)
|
||||
|
||||
model_params = ModelCardParameters(
|
||||
max_seq_len=self.config.max_seq_len,
|
||||
cache_size=self.cache_size,
|
||||
rope_scale=self.config.scale_pos_emb,
|
||||
rope_alpha=self.config.scale_alpha_value,
|
||||
max_batch_size=self.max_batch_size,
|
||||
cache_mode=self.cache_mode,
|
||||
chunk_size=self.config.max_input_len,
|
||||
use_vision=self.use_vision,
|
||||
draft=draft_model_card,
|
||||
)
|
||||
|
||||
if self.prompt_template:
|
||||
model_params["prompt_template"] = self.prompt_template.name
|
||||
model_params["prompt_template_content"] = self.prompt_template.raw_template
|
||||
model_params.prompt_template = self.prompt_template.name
|
||||
model_params.prompt_template_content = self.prompt_template.raw_template
|
||||
|
||||
if self.draft_config:
|
||||
draft_model_params = {
|
||||
"name": self.draft_model_dir.name,
|
||||
"rope_scale": self.draft_config.scale_pos_emb,
|
||||
"rope_alpha": self.draft_config.scale_alpha_value,
|
||||
"max_seq_len": self.draft_config.max_seq_len,
|
||||
"cache_mode": self.draft_cache_mode,
|
||||
}
|
||||
model_card = ModelCard(
|
||||
id=self.model_dir.name,
|
||||
parameters=model_params,
|
||||
)
|
||||
|
||||
model_params["draft"] = draft_model_params
|
||||
|
||||
return model_params
|
||||
return model_card
|
||||
|
||||
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""Polling mechanism to wait for pending generation jobs."""
|
||||
|
||||
Reference in New Issue
Block a user