Merge pull request #322 from theroyallab/model-rewrite

Model rewrite
This commit is contained in:
Brian
2025-04-26 02:15:48 -04:00
committed by GitHub
15 changed files with 758 additions and 535 deletions

View File

@@ -0,0 +1,244 @@
import abc
import asyncio
import pathlib
from loguru import logger
from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
)
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from endpoints.core.types.model import ModelCard
class BaseModelContainer(abc.ABC):
    """
    Abstract base class for model containers.

    A backend implements every method marked abstract below. The methods in
    the "Optional methods" section ship with a default implementation that a
    backend may override.
    """

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None
    generation_config: Optional[GenerationConfig] = None

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    # NOTE(review): active_job_ids is a class-level mutable dict, so it is
    # shared by every subclass and instance that does not rebind it.
    # Confirm that this sharing is intended.
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock
    load_condition: asyncio.Condition

    # Required methods

    @classmethod
    @abc.abstractmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        """
        Asynchronously creates and initializes a model container instance.

        Args:
            model_directory: Path to the model files.
            **kwargs: Backend-specific configuration options.

        Returns:
            An instance of the implementing class.
        """

        pass

    @abc.abstractmethod
    async def load(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.
        """

        pass

    # NOTE: Might be an optional method
    @abc.abstractmethod
    async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
        """
        Loads the model into memory, yielding progress updates.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.

        Yields:
            Progress updates
        """

        # The unreachable yield makes this stub an async generator function,
        # which matches the AsyncIterator return annotation
        if False:
            yield

    @abc.abstractmethod
    async def unload(self, loras_only: bool = False, **kwargs):
        """
        Unloads the model and associated resources from memory.

        Args:
            loras_only: If True, only unload LoRAs.
            **kwargs: Additional unloading options (e.g., shutdown).
        """

        pass

    @abc.abstractmethod
    def encode_tokens(self, text: str, **kwargs) -> List[int]:
        """
        Encodes a string of text into a list of token IDs.

        Args:
            text: The input text string.
            **kwargs: Backend-specific encoding options (e.g., add_bos_token).

        Returns:
            A list of integer token IDs.
        """

        pass

    @abc.abstractmethod
    def decode_tokens(self, ids: List[int], **kwargs) -> str:
        """
        Decodes a list of token IDs back into a string.

        Args:
            ids: A list of integer token IDs.
            **kwargs: Backend-specific decoding options
                (e.g., decode_special_tokens).

        Returns:
            The decoded text string.
        """

        pass

    @abc.abstractmethod
    def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
        """
        Gets special tokens used by the model/tokenizer.

        Args:
            **kwargs: Options like add_bos_token, ban_eos_token.

        Returns:
            A dictionary mapping special token names
            (e.g., 'bos_token', 'eos_token') to their string or ID
            representation.
        """

        pass

    @abc.abstractmethod
    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> Dict[str, Any]:
        """
        Generates a complete response for a given prompt and parameters.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Returns:
            A dictionary containing the generation info
        """

        pass

    @abc.abstractmethod
    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Generates a response iteratively (streaming) for a given prompt.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Yields:
            Generation chunks
        """

        # The unreachable yield makes this stub an async generator function,
        # which matches the AsyncIterator return annotation
        if False:
            yield

    @abc.abstractmethod
    def model_info(self) -> ModelCard:
        """
        Returns a model card describing the current model.

        Returns:
            A ModelCard holding the model parameters provided by the backend
        """

        pass

    @abc.abstractmethod
    async def wait_for_jobs(self, skip_wait: bool = False):
        """
        Waits for any active generation jobs to complete.

        Args:
            skip_wait: If True, cancel jobs immediately instead of waiting.
        """

        pass

    # Optional methods

    async def load_loras(
        self, lora_directory: pathlib.Path, **kwargs
    ) -> Dict[str, List[str]]:
        """
        Loads LoRA adapters.

        The base implementation loads nothing. It logs a warning and reports
        every requested LoRA as a failure.

        Args:
            lora_directory: Path to the directory containing LoRA files.
            **kwargs: LoRA configuration (e.g., list of loras, scaling).

        Returns:
            A dictionary with "success" and "failure" lists of LoRA names.
        """

        logger.warning("LoRA loading not implemented for this backend.")  # type: ignore

        # A caller may pass loras=None explicitly; treat that like an
        # empty request instead of failing while iterating over None
        requested_loras = kwargs.get("loras") or []

        return {
            "success": [],
            "failure": [lora.get("name", "unknown") for lora in requested_loras],
        }

    def get_loras(self) -> List[Any]:
        """
        Gets the currently loaded LoRA adapters.

        The base implementation returns an empty list.

        Returns:
            A list representing the loaded LoRAs (backend-specific format).
        """

        return []

File diff suppressed because it is too large Load Diff

View File

@@ -14,8 +14,7 @@ if dependencies.extras:
class InfinityContainer: class InfinityContainer:
model_dir: pathlib.Path model_dir: pathlib.Path
model_is_loading: bool = False loaded: bool = False
model_loaded: bool = False
# Use a runtime type hint here # Use a runtime type hint here
engine: Optional["AsyncEmbeddingEngine"] = None engine: Optional["AsyncEmbeddingEngine"] = None
@@ -24,8 +23,6 @@ class InfinityContainer:
self.model_dir = model_directory self.model_dir = model_directory
async def load(self, **kwargs): async def load(self, **kwargs):
self.model_is_loading = True
# Use cpu by default # Use cpu by default
device = unwrap(kwargs.get("embeddings_device"), "cpu") device = unwrap(kwargs.get("embeddings_device"), "cpu")
@@ -40,7 +37,7 @@ class InfinityContainer:
self.engine = AsyncEmbeddingEngine.from_args(engine_args) self.engine = AsyncEmbeddingEngine.from_args(engine_args)
await self.engine.astart() await self.engine.astart()
self.model_loaded = True self.loaded = True
logger.info("Embedding model successfully loaded.") logger.info("Embedding model successfully loaded.")
async def unload(self): async def unload(self):

View File

@@ -4,24 +4,29 @@ Manages the storage and utility of model containers.
Containers exist as a common interface for backends. Containers exist as a common interface for backends.
""" """
import aiofiles
import pathlib import pathlib
from enum import Enum from enum import Enum
from fastapi import HTTPException from fastapi import HTTPException
from loguru import logger from loguru import logger
from ruamel.yaml import YAML
from typing import Optional from typing import Optional
from backends.base_model_container import BaseModelContainer
from common.logger import get_loading_progress_bar from common.logger import get_loading_progress_bar
from common.networking import handle_request_error from common.networking import handle_request_error
from common.tabby_config import config from common.tabby_config import config
from common.optional_dependencies import dependencies from common.optional_dependencies import dependencies
from common.utils import unwrap
# Global variables for model container
container: Optional[BaseModelContainer] = None
embeddings_container = None
# FIXME: Possibly use this solely when creating the model
if dependencies.exllamav2: if dependencies.exllamav2:
from backends.exllamav2.model import ExllamaV2Container from backends.exllamav2.model import ExllamaV2Container
# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None
if dependencies.extras: if dependencies.extras:
from backends.infinity.model import InfinityContainer from backends.infinity.model import InfinityContainer
@@ -41,6 +46,36 @@ def load_progress(module, modules):
yield module, modules yield module, modules
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
    """
    Layers a model folder's tabby_config.yml underneath the given kwargs.

    Values already present in kwargs win over values read from the file.
    The nested "draft_model" mapping is combined key by key under the same
    rule. When the folder has no tabby_config.yml, kwargs is returned as is.
    """

    config_path = model_dir / "tabby_config.yml"
    if not config_path.exists():
        return kwargs

    async with aiofiles.open(config_path, "r", encoding="utf8") as config_file:
        raw_yaml = await config_file.read()

    # A throwaway safe parser is enough for this one-off read
    inline_args = unwrap(YAML(typ="safe").load(raw_yaml), {})

    # The draft section is nested, so it has to be combined before the
    # top-level merge would replace it wholesale
    inline_draft_args = unwrap(inline_args.get("draft_model"), {})
    if inline_draft_args:
        kwargs["draft_model"] = {
            **inline_draft_args,
            **unwrap(kwargs.get("draft_model"), {}),
        }

    # File values first, caller values last so that the caller wins
    return {**inline_args, **kwargs}
async def unload_model(skip_wait: bool = False, shutdown: bool = False): async def unload_model(skip_wait: bool = False, shutdown: bool = False):
"""Unloads a model""" """Unloads a model"""
global container global container
@@ -57,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
if container and container.model: if container and container.model:
loaded_model_name = container.model_dir.name loaded_model_name = container.model_dir.name
if loaded_model_name == model_path.name and container.model_loaded: if loaded_model_name == model_path.name and container.loaded:
raise ValueError( raise ValueError(
f'Model "{loaded_model_name}" is already loaded! Aborting.' f'Model "{loaded_model_name}" is already loaded! Aborting.'
) )
@@ -65,22 +100,34 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.") logger.info("Unloading existing model.")
await unload_model() await unload_model()
# Merge with config defaults # Reset to prepare for a new container
container = None
# Model_dir is already provided
if "model_dir" in kwargs:
kwargs.pop("model_dir")
# Merge with config and inline defaults
# TODO: Figure out a way to do this with Pydantic validation
# and ModelLoadRequest. Pydantic doesn't have async validators
kwargs = {**config.model_defaults, **kwargs} kwargs = {**config.model_defaults, **kwargs}
kwargs = await apply_inline_overrides(model_path, **kwargs)
# Create a new container # Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) new_container = await ExllamaV2Container.create(
model_path.resolve(), False, **kwargs
)
# Add possible types of models that can be loaded # Add possible types of models that can be loaded
model_type = [ModelType.MODEL] model_type = [ModelType.MODEL]
if container.use_vision: if new_container.use_vision:
model_type.insert(0, ModelType.VISION) model_type.insert(0, ModelType.VISION)
if container.draft_config: if new_container.draft_config:
model_type.insert(0, ModelType.DRAFT) model_type.insert(0, ModelType.DRAFT)
load_status = container.load_gen(load_progress, **kwargs) load_status = new_container.load_gen(load_progress, **kwargs)
progress = get_loading_progress_bar() progress = get_loading_progress_bar()
progress.start() progress.start()
@@ -104,6 +151,8 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
progress.stop() progress.stop()
else: else:
index += 1 index += 1
container = new_container
finally: finally:
progress.stop() progress.stop()
@@ -142,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
if embeddings_container and embeddings_container.engine: if embeddings_container and embeddings_container.engine:
loaded_model_name = embeddings_container.model_dir.name loaded_model_name = embeddings_container.model_dir.name
if loaded_model_name == model_path.name and embeddings_container.model_loaded: if loaded_model_name == model_path.name and embeddings_container.loaded:
raise ValueError( raise ValueError(
f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
) )
@@ -150,8 +199,13 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing embeddings model.") logger.info("Unloading existing embeddings model.")
await unload_embedding_model() await unload_embedding_model()
embeddings_container = InfinityContainer(model_path) # Reset to prepare for a new container
await embeddings_container.load(**kwargs) embeddings_container = None
new_embeddings_container = InfinityContainer(model_path)
await new_embeddings_container.load(**kwargs)
embeddings_container = new_embeddings_container
async def unload_embedding_model(): async def unload_embedding_model():
@@ -164,13 +218,13 @@ async def unload_embedding_model():
async def check_model_container(): async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading.""" """FastAPI depends that checks if a model isn't loaded or currently loading."""
if container is None or not (container.model_is_loading or container.model_loaded): if container is None:
error_message = handle_request_error( error_message = handle_request_error(
"No models are currently loaded.", "No models are currently loaded.",
exc_info=False, exc_info=False,
).error.message ).error.message
raise HTTPException(400, error_message) raise HTTPException(503, error_message)
async def check_embeddings_container(): async def check_embeddings_container():
@@ -180,12 +234,10 @@ async def check_embeddings_container():
This is the same as the model container check, but with embeddings instead. This is the same as the model container check, but with embeddings instead.
""" """
if embeddings_container is None or not ( if embeddings_container is None:
embeddings_container.model_is_loading or embeddings_container.model_loaded
):
error_message = handle_request_error( error_message = handle_request_error(
"No embedding models are currently loaded.", "No embedding models are currently loaded.",
exc_info=False, exc_info=False,
).error.message ).error.message
raise HTTPException(400, error_message) raise HTTPException(503, error_message)

View File

@@ -41,12 +41,6 @@ class BaseSamplerRequest(BaseModel):
ge=0, ge=0,
) )
generate_window: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("generate_window"),
examples=[512],
ge=0,
)
stop: Optional[Union[str, List[Union[str, int]]]] = Field( stop: Optional[Union[str, List[Union[str, int]]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", []), default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"), validation_alias=AliasChoices("stop", "stop_sequence"),
@@ -165,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
"rep_pen_range", "rep_pen_range",
), ),
description=( description=(
"Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range" "Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
), ),
) )
@@ -281,6 +275,11 @@ class BaseSamplerRequest(BaseModel):
ge=0, ge=0,
) )
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0),
ge=0,
)
@field_validator("top_k", mode="before") @field_validator("top_k", mode="before")
def convert_top_k(cls, v): def convert_top_k(cls, v):
"""Fixes instance if Top-K is -1.""" """Fixes instance if Top-K is -1."""

View File

@@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency""" """Small replication of AutoTokenizer's chat template system for efficiency"""
import traceback
import aiofiles import aiofiles
import json import json
import pathlib import pathlib
@@ -211,3 +212,56 @@ def find_template_from_model(model_path: pathlib.Path):
return template_name return template_name
else: else:
raise TemplateLoadError("Could not find template from model name.") raise TemplateLoadError("Could not find template from model name.")
async def find_prompt_template(template_name, model_dir: pathlib.Path):
    """
    Tries a series of prompt template sources and returns the first hit.

    Sources are attempted in this order:
      1. The named template in the "templates" folder, then the named entry
         in tokenizer_config.json (only when template_name is given)
      2. tabby_template.jinja inside the model directory (only if present)
      3. chat_template.json inside the model directory
      4. tokenizer_config.json inside the model directory
      5. A template looked up from the model directory via
         find_template_from_model

    A source that raises is logged and skipped. None is returned when no
    source produces a template.
    """

    logger.info("Attempting to load a prompt template if present.")

    tokenizer_config_path = model_dir / "tokenizer_config.json"
    local_template_path = model_dir / "tabby_template.jinja"

    # Build the lookup chain directly in priority order
    loaders = []

    # An explicitly requested template name takes precedence
    if template_name:
        loaders.append(
            lambda: PromptTemplate.from_file(pathlib.Path("templates") / template_name)
        )
        loaders.append(
            lambda: PromptTemplate.from_model_json(
                tokenizer_config_path,
                key="chat_template",
                name=template_name,
            )
        )

    # Next comes a template file that sits in the model directory
    if local_template_path.exists():
        loaders.append(lambda: PromptTemplate.from_file(local_template_path))

    # Fallbacks that are always attempted
    loaders.append(
        lambda: PromptTemplate.from_model_json(
            model_dir / "chat_template.json",
            key="chat_template",
        )
    )
    loaders.append(
        lambda: PromptTemplate.from_model_json(
            tokenizer_config_path,
            key="chat_template",
        )
    )
    loaders.append(
        lambda: PromptTemplate.from_file(find_template_from_model(model_dir))
    )

    # A failing source only means the next one gets its turn
    for load_template in loaders:
        try:
            found_template = await load_template()
        except TemplateLoadError as e:
            logger.warning(f"TemplateLoadError: {str(e)}")
        except Exception:
            logger.error(traceback.format_exc())
            logger.warning(
                "An unexpected error happened when trying to load the template. "
                "Trying other methods."
            )
        else:
            if found_template is not None:
                return found_template

    return None

View File

@@ -85,3 +85,24 @@ def unwrap_optional_type(type_hint) -> Type:
return arg return arg
return type_hint return type_hint
def calculate_rope_alpha(base_seq_len: int, target_seq_len: int):
    """
    Derives a rope alpha value from a desired max sequence length.

    Args:
        base_seq_len: The model's configured sequence length.
        target_seq_len: The user-specified max sequence length.

    Returns:
        1 when the target does not exceed the base length. Otherwise the
        value of a quadratic polynomial evaluated at
        target_seq_len / base_seq_len.
    """

    # How many times longer the target length is than the base length
    scale = target_seq_len / base_seq_len

    # No adjustment is needed when the target fits within the base length
    if scale <= 1.0:
        return 1

    return -0.13436 + 0.80541 * scale + 0.28833 * scale**2

View File

@@ -122,7 +122,7 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
async def get_max_length() -> MaxLengthResponse: async def get_max_length() -> MaxLengthResponse:
"""Fetches the max length of the model.""" """Fetches the max length of the model."""
max_length = model.container.get_model_parameters().get("max_seq_len") max_length = model.container.model_info().parameters.max_seq_len
return {"value": max_length} return {"value": max_length}

View File

@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
try: try:
logger.info(f"Received Kobold generation request {data.genkey}") logger.info(f"Received Kobold generation request {data.genkey}")
generator = model.container.generate_gen( generator = model.container.stream_generate(
request_id=data.genkey, abort_event=abort_event, **data.model_dump() request_id=data.genkey, abort_event=abort_event, **data.model_dump()
) )
async for generation in generator: async for generation in generator:

View File

@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Generation info (remainder is in BaseSamplerRequest superclass) # Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[ChatCompletionStreamOptions] = None stream_options: Optional[ChatCompletionStreamOptions] = None
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0),
ge=0,
)
response_format: Optional[CompletionResponseFormat] = Field( response_format: Optional[CompletionResponseFormat] = Field(
default_factory=CompletionResponseFormat default_factory=CompletionResponseFormat
) )

View File

@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
_stream_collector( _stream_collector(
n, n,
gen_queue, gen_queue,
prompt,
request.state.id, request.state.id,
prompt,
task_gen_params,
abort_event, abort_event,
embeddings=embeddings, mm_embeddings=embeddings,
**task_gen_params.model_dump(exclude={"prompt"}),
) )
) )
@@ -422,10 +422,10 @@ async def generate_chat_completion(
gen_tasks.append( gen_tasks.append(
asyncio.create_task( asyncio.create_task(
model.container.generate( model.container.generate(
prompt,
request.state.id, request.state.id,
embeddings=embeddings, prompt,
**data.model_dump(exclude={"prompt"}), data,
mm_embeddings=embeddings,
) )
) )
) )
@@ -465,7 +465,6 @@ async def generate_tool_calls(
# FIXME: May not be necessary depending on how the codebase evolves # FIXME: May not be necessary depending on how the codebase evolves
tool_data = data.model_copy(deep=True) tool_data = data.model_copy(deep=True)
tool_data.json_schema = tool_data.tool_call_schema tool_data.json_schema = tool_data.tool_call_schema
gen_params = tool_data.model_dump()
for idx, gen in enumerate(generations): for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start: if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
gen_tasks.append( gen_tasks.append(
asyncio.create_task( asyncio.create_task(
model.container.generate( model.container.generate(
pre_tool_prompt,
request.state.id, request.state.id,
pre_tool_prompt,
tool_data,
embeddings=mm_embeddings, embeddings=mm_embeddings,
**gen_params,
) )
) )
) )

View File

@@ -8,12 +8,12 @@ import asyncio
import pathlib import pathlib
from asyncio import CancelledError from asyncio import CancelledError
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from typing import List, Union
from loguru import logger from loguru import logger
from typing import List, Optional, Union
from common import model from common import model
from common.auth import get_key_permission from common.auth import get_key_permission
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import ( from common.networking import (
get_generator_error, get_generator_error,
handle_request_disconnect, handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
async def _stream_collector( async def _stream_collector(
task_idx: int, task_idx: int,
gen_queue: asyncio.Queue, gen_queue: asyncio.Queue,
prompt: str,
request_id: str, request_id: str,
prompt: str,
params: CompletionRequest,
abort_event: asyncio.Event, abort_event: asyncio.Event,
**kwargs, mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
): ):
"""Collects a stream and places results in a common queue""" """Collects a stream and places results in a common queue"""
try: try:
new_generation = model.container.generate_gen( new_generation = model.container.stream_generate(
prompt, request_id, abort_event, **kwargs request_id,
prompt,
params,
abort_event,
mm_embeddings,
) )
async for generation in new_generation: async for generation in new_generation:
generation["index"] = task_idx generation["index"] = task_idx
@@ -115,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
if ( if (
model.container model.container
and model.container.model_dir.name == model_name and model.container.model_dir.name == model_name
and model.container.model_loaded and model.container.loaded
): ):
return return
@@ -195,10 +200,10 @@ async def stream_generate_completion(
_stream_collector( _stream_collector(
n, n,
gen_queue, gen_queue,
data.prompt,
request.state.id, request.state.id,
data.prompt,
task_gen_params,
abort_event, abort_event,
**task_gen_params.model_dump(exclude={"prompt"}),
) )
) )
@@ -256,9 +261,9 @@ async def generate_completion(
gen_tasks.append( gen_tasks.append(
asyncio.create_task( asyncio.create_task(
model.container.generate( model.container.generate(
data.prompt,
request.state.id, request.state.id,
**task_gen_params.model_dump(exclude={"prompt"}), data.prompt,
task_gen_params,
) )
) )
) )

View File

@@ -5,10 +5,8 @@ from typing import Optional
from common import model from common import model
from common.networking import get_generator_error, handle_request_disconnect from common.networking import get_generator_error, handle_request_disconnect
from common.tabby_config import config from common.tabby_config import config
from common.utils import unwrap
from endpoints.core.types.model import ( from endpoints.core.types.model import (
ModelCard, ModelCard,
ModelCardParameters,
ModelList, ModelList,
ModelLoadRequest, ModelLoadRequest,
ModelLoadResponse, ModelLoadResponse,
@@ -64,30 +62,7 @@ async def get_current_model_list(model_type: str = "model"):
def get_current_model(): def get_current_model():
"""Gets the current model with all parameters.""" """Gets the current model with all parameters."""
model_params = model.container.get_model_parameters() model_card = model.container.model_info()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=config.logging,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card return model_card

View File

@@ -40,8 +40,6 @@ dependencies = [
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
"winloop ; platform_system == 'Windows'", "winloop ; platform_system == 'Windows'",
"numpy < 2.0.0",
# For python 3.12 # For python 3.12
"setuptools ; python_version >= '3.12'" "setuptools ; python_version >= '3.12'"
] ]
@@ -60,55 +58,55 @@ dev = [
] ]
cu121 = [ cu121 = [
# Torch (Extra index URLs not supported in pyproject.toml) # Torch (Extra index URLs not supported in pyproject.toml)
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2 # Exl2
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Windows FA2 from https://github.com/kingbri1/flash-attention/releases # Windows FA2 from https://github.com/kingbri1/flash-attention/releases
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases # Linux FA2 from https://github.com/kingbri1/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
] ]
amd = [ amd = [
# Torch triton for ROCm # Torch triton for ROCm
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Torch # Torch
"torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2 # Exl2
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
] ]
# MARK: Ruff options # MARK: Ruff options

View File

@@ -14,9 +14,6 @@ max_tokens:
min_tokens: min_tokens:
override: 0 override: 0
force: false force: false
generate_window:
override: 512
force: false
stop: stop:
override: [] override: []
force: false force: false