mirror of https://github.com/theroyallab/tabbyAPI.git

backends/base_model_container.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import abc
import asyncio
import pathlib
from loguru import logger
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from endpoints.core.types.model import ModelCard


class BaseModelContainer(abc.ABC):
    """Abstract base class for model containers."""

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None
    generation_config: Optional[GenerationConfig] = None

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock
    load_condition: asyncio.Condition

    # Required methods
    @classmethod
    @abc.abstractmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        """
        Asynchronously creates and initializes a model container instance.

        Args:
            model_directory: Path to the model files.
            **kwargs: Backend-specific configuration options.

        Returns:
            An instance of the implementing class.
        """

        pass

    @abc.abstractmethod
    async def load(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.
        """

        pass

    # NOTE: Might be an optional method
    @abc.abstractmethod
    async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
        """
        Loads the model into memory, yielding progress updates.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.

        Yields:
            Progress updates
        """

        if False:
            yield

    @abc.abstractmethod
    async def unload(self, loras_only: bool = False, **kwargs):
        """
        Unloads the model and associated resources from memory.

        Args:
            loras_only: If True, only unload LoRAs.
            **kwargs: Additional unloading options (e.g., shutdown).
        """

        pass

    @abc.abstractmethod
    def encode_tokens(self, text: str, **kwargs) -> List[int]:
        """
        Encodes a string of text into a list of token IDs.

        Args:
            text: The input text string.
            **kwargs: Backend-specific encoding options (e.g., add_bos_token).

        Returns:
            A list of integer token IDs.
        """

        pass

    @abc.abstractmethod
    def decode_tokens(self, ids: List[int], **kwargs) -> str:
        """
        Decodes a list of token IDs back into a string.

        Args:
            ids: A list of integer token IDs.
            **kwargs: Backend-specific decoding options (e.g., decode_special_tokens).

        Returns:
            The decoded text string.
        """

        pass

    @abc.abstractmethod
    def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
        """
        Gets special tokens used by the model/tokenizer.

        Args:
            **kwargs: Options like add_bos_token, ban_eos_token.

        Returns:
            A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
            to their string or ID representation.
        """

        pass

    @abc.abstractmethod
    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> Dict[str, Any]:
        """
        Generates a complete response for a given prompt and parameters.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Returns:
            A dictionary containing the generation info
        """

        pass

    @abc.abstractmethod
    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Generates a response iteratively (streaming) for a given prompt.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Yields:
            Generation chunks
        """

        if False:
            yield

    @abc.abstractmethod
    def model_info(self) -> ModelCard:
        """
        Returns a dictionary of the current model's configuration parameters.

        Returns:
            Model parameters provided by the backend
        """

        pass

    @abc.abstractmethod
    async def wait_for_jobs(self, skip_wait: bool = False):
        """
        Waits for any active generation jobs to complete.

        Args:
            skip_wait: If True, cancel jobs immediately instead of waiting.
        """

        pass

    # Optional methods
    async def load_loras(
        self, lora_directory: pathlib.Path, **kwargs
    ) -> Dict[str, List[str]]:
        """
        Loads LoRA adapters. Base implementation does nothing or raises error.

        Args:
            lora_directory: Path to the directory containing LoRA files.
            **kwargs: LoRA configuration (e.g., list of loras, scaling).

        Returns:
            A dictionary indicating success/failure for each LoRA.
        """

        logger.warning("LoRA loading not implemented for this backend.")  # type: ignore
        return {
            "success": [],
            "failure": [
                lora.get("name", "unknown") for lora in kwargs.get("loras", [])
            ],
        }

    def get_loras(self) -> List[Any]:
        """
        Gets the currently loaded LoRA adapters. Base implementation returns empty list.

        Returns:
            A list representing the loaded LoRAs (backend-specific format).
        """

        return []
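For context only (this is not part of the diff): a concrete backend is expected to subclass this interface and implement every abstract method. A minimal, hypothetical sketch of a conforming container is shown below, reusing the imports from the file above; the class name, the toy tokenizer, and the generation dictionary keys are illustrative assumptions, not tabbyAPI code.

# Illustrative only: a toy echo backend conforming to BaseModelContainer.
class EchoModelContainer(BaseModelContainer):
    @classmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        self = cls()
        self.model_dir = model_directory
        self.load_lock = asyncio.Lock()
        self.load_condition = asyncio.Condition()
        return self

    async def load(self, progress_callback=None, **kwargs):
        self.loaded = True

    async def load_gen(self, progress_callback=None, **kwargs):
        await self.load(progress_callback, **kwargs)
        yield 1, 1  # one "module" out of one loaded

    async def unload(self, loras_only: bool = False, **kwargs):
        self.loaded = False

    def encode_tokens(self, text: str, **kwargs):
        return [ord(c) for c in text]  # toy tokenizer

    def decode_tokens(self, ids, **kwargs):
        return "".join(chr(i) for i in ids)

    def get_special_tokens(self, **kwargs):
        return {"bos_token": "", "eos_token": ""}

    async def generate(
        self, request_id, prompt, params, abort_event=None, mm_embeddings=None
    ):
        # Keys of the returned dict are assumptions for illustration.
        return {"text": prompt, "prompt_tokens": len(prompt), "gen_tokens": 0}

    async def stream_generate(
        self, request_id, prompt, params, abort_event=None, mm_embeddings=None
    ):
        yield {"text": prompt}

    def model_info(self):
        return ModelCard(id=self.model_dir.name)

    async def wait_for_jobs(self, skip_wait: bool = False):
        while self.active_job_ids and not skip_wait:
            await asyncio.sleep(0.1)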
File diff suppressed because it is too large
@@ -14,8 +14,7 @@ if dependencies.extras:

 class InfinityContainer:
     model_dir: pathlib.Path
-    model_is_loading: bool = False
-    model_loaded: bool = False
+    loaded: bool = False

     # Use a runtime type hint here
     engine: Optional["AsyncEmbeddingEngine"] = None
@@ -24,8 +23,6 @@ class InfinityContainer:
         self.model_dir = model_directory

     async def load(self, **kwargs):
-        self.model_is_loading = True
-
         # Use cpu by default
         device = unwrap(kwargs.get("embeddings_device"), "cpu")

@@ -40,7 +37,7 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()

-        self.model_loaded = True
+        self.loaded = True
         logger.info("Embedding model successfully loaded.")

     async def unload(self):
@@ -4,24 +4,29 @@ Manages the storage and utility of model containers.
 Containers exist as a common interface for backends.
 """

+import aiofiles
 import pathlib
 from enum import Enum
 from fastapi import HTTPException
 from loguru import logger
+from ruamel.yaml import YAML
 from typing import Optional

+from backends.base_model_container import BaseModelContainer
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.utils import unwrap
+
+# Global variables for model container
+container: Optional[BaseModelContainer] = None
+embeddings_container = None

+# FIXME: Possibly use this solely when creating the model
 if dependencies.exllamav2:
     from backends.exllamav2.model import ExllamaV2Container

-# Global model container
-container: Optional[ExllamaV2Container] = None
-embeddings_container = None
-
-
 if dependencies.extras:
     from backends.infinity.model import InfinityContainer
@@ -41,6 +46,36 @@ def load_progress(module, modules):
     yield module, modules


+async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
+    """Sets overrides from a model folder's config yaml."""
+
+    override_config_path = model_dir / "tabby_config.yml"
+
+    if not override_config_path.exists():
+        return kwargs
+
+    async with aiofiles.open(
+        override_config_path, "r", encoding="utf8"
+    ) as override_config_file:
+        contents = await override_config_file.read()
+
+    # Create a temporary YAML parser
+    yaml = YAML(typ="safe")
+    override_args = unwrap(yaml.load(contents), {})
+
+    # Merge draft overrides beforehand
+    draft_override_args = unwrap(override_args.get("draft_model"), {})
+    if draft_override_args:
+        kwargs["draft_model"] = {
+            **draft_override_args,
+            **unwrap(kwargs.get("draft_model"), {}),
+        }
+
+    # Merge the override and model kwargs
+    merged_kwargs = {**override_args, **kwargs}
+    return merged_kwargs
+
+
 async def unload_model(skip_wait: bool = False, shutdown: bool = False):
     """Unloads a model"""
     global container
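One detail worth noting about apply_inline_overrides: because the request kwargs are spread last, options passed explicitly at load time win over values from the model folder's tabby_config.yml. A minimal illustration of that merge order (the values here are made up):

# Folder overrides are spread first, so explicit request kwargs win.
override_args = {"max_seq_len": 16384, "cache_mode": "Q4"}  # hypothetical tabby_config.yml
kwargs = {"max_seq_len": 8192}                              # hypothetical request options
merged = {**override_args, **kwargs}
assert merged == {"max_seq_len": 8192, "cache_mode": "Q4"}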
@@ -57,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if container and container.model:
         loaded_model_name = container.model_dir.name

-        if loaded_model_name == model_path.name and container.model_loaded:
+        if loaded_model_name == model_path.name and container.loaded:
             raise ValueError(
                 f'Model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -65,22 +100,34 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
         logger.info("Unloading existing model.")
         await unload_model()

-    # Merge with config defaults
+    # Reset to prepare for a new container
+    container = None
+
+    # Model_dir is already provided
+    if "model_dir" in kwargs:
+        kwargs.pop("model_dir")
+
+    # Merge with config and inline defaults
+    # TODO: Figure out a way to do this with Pydantic validation
+    # and ModelLoadRequest. Pydantic doesn't have async validators
     kwargs = {**config.model_defaults, **kwargs}
+    kwargs = await apply_inline_overrides(model_path, **kwargs)

     # Create a new container
-    container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
+    new_container = await ExllamaV2Container.create(
+        model_path.resolve(), False, **kwargs
+    )

     # Add possible types of models that can be loaded
     model_type = [ModelType.MODEL]

-    if container.use_vision:
+    if new_container.use_vision:
         model_type.insert(0, ModelType.VISION)

-    if container.draft_config:
+    if new_container.draft_config:
         model_type.insert(0, ModelType.DRAFT)

-    load_status = container.load_gen(load_progress, **kwargs)
+    load_status = new_container.load_gen(load_progress, **kwargs)

     progress = get_loading_progress_bar()
     progress.start()
@@ -104,6 +151,8 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
                     progress.stop()
                 else:
                     index += 1
+
+        container = new_container
     finally:
         progress.stop()
@@ -142,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     if embeddings_container and embeddings_container.engine:
         loaded_model_name = embeddings_container.model_dir.name

-        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+        if loaded_model_name == model_path.name and embeddings_container.loaded:
             raise ValueError(
                 f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -150,8 +199,13 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
         logger.info("Unloading existing embeddings model.")
         await unload_embedding_model()

-    embeddings_container = InfinityContainer(model_path)
-    await embeddings_container.load(**kwargs)
+    # Reset to prepare for a new container
+    embeddings_container = None
+
+    new_embeddings_container = InfinityContainer(model_path)
+    await new_embeddings_container.load(**kwargs)
+
+    embeddings_container = new_embeddings_container


 async def unload_embedding_model():
@@ -164,13 +218,13 @@ async def unload_embedding_model():
 async def check_model_container():
     """FastAPI depends that checks if a model isn't loaded or currently loading."""

-    if container is None or not (container.model_is_loading or container.model_loaded):
+    if container is None:
         error_message = handle_request_error(
             "No models are currently loaded.",
             exc_info=False,
         ).error.message

-        raise HTTPException(400, error_message)
+        raise HTTPException(503, error_message)


 async def check_embeddings_container():
@@ -180,12 +234,10 @@ async def check_embeddings_container():
     This is the same as the model container check, but with embeddings instead.
     """

-    if embeddings_container is None or not (
-        embeddings_container.model_is_loading or embeddings_container.model_loaded
-    ):
+    if embeddings_container is None:
         error_message = handle_request_error(
             "No embedding models are currently loaded.",
             exc_info=False,
         ).error.message

-        raise HTTPException(400, error_message)
+        raise HTTPException(503, error_message)
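For context, these check functions are used as FastAPI dependencies, so routes that need a loaded model fail fast with the 503 above. A rough sketch of how a route would attach the dependency (the route path and handler name are illustrative, not taken from this diff):

from fastapi import APIRouter, Depends

from common import model
from common.model import check_model_container

router = APIRouter()


@router.get("/v1/model", dependencies=[Depends(check_model_container)])
async def current_model():
    # Only reached when a container exists; otherwise the dependency raises a 503.
    return model.container.model_info()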
@@ -41,12 +41,6 @@ class BaseSamplerRequest(BaseModel):
         ge=0,
     )

-    generate_window: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("generate_window"),
-        examples=[512],
-        ge=0,
-    )
-
     stop: Optional[Union[str, List[Union[str, int]]]] = Field(
         default_factory=lambda: get_default_sampler_value("stop", []),
         validation_alias=AliasChoices("stop", "stop_sequence"),
@@ -165,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
             "rep_pen_range",
         ),
         description=(
-            "Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
+            "Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
         ),
     )

@@ -281,6 +275,11 @@ class BaseSamplerRequest(BaseModel):
         ge=0,
     )

+    logprobs: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("logprobs", 0),
+        ge=0,
+    )
+
     @field_validator("top_k", mode="before")
     def convert_top_k(cls, v):
         """Fixes instance if Top-K is -1."""
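With logprobs promoted onto BaseSamplerRequest, any request schema built on it accepts and validates the field. An illustrative request body (the prompt and values are made up):

# Illustrative only: logprobs must satisfy the Field(ge=0) constraint above.
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "logprobs": 5,
}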
@@ -1,5 +1,6 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""

+import traceback
 import aiofiles
 import json
 import pathlib
@@ -211,3 +212,56 @@ def find_template_from_model(model_path: pathlib.Path):
         return template_name
     else:
         raise TemplateLoadError("Could not find template from model name.")
+
+
+async def find_prompt_template(template_name, model_dir: pathlib.Path):
+    """Tries to find a prompt template using various methods."""
+
+    logger.info("Attempting to load a prompt template if present.")
+
+    find_template_functions = [
+        lambda: PromptTemplate.from_model_json(
+            model_dir / "chat_template.json",
+            key="chat_template",
+        ),
+        lambda: PromptTemplate.from_model_json(
+            model_dir / "tokenizer_config.json",
+            key="chat_template",
+        ),
+        lambda: PromptTemplate.from_file(find_template_from_model(model_dir)),
+    ]
+
+    # Find the template in the model directory if it exists
+    model_dir_template_path = model_dir / "tabby_template.jinja"
+    if model_dir_template_path.exists():
+        find_template_functions[:0] = [
+            lambda: PromptTemplate.from_file(model_dir_template_path)
+        ]
+
+    # Add lookup from prompt template name if provided
+    if template_name:
+        find_template_functions[:0] = [
+            lambda: PromptTemplate.from_file(pathlib.Path("templates") / template_name),
+            lambda: PromptTemplate.from_model_json(
+                model_dir / "tokenizer_config.json",
+                key="chat_template",
+                name=template_name,
+            ),
+        ]
+
+    # Continue on exception since functions are tried as they fail
+    for template_func in find_template_functions:
+        try:
+            prompt_template = await template_func()
+            if prompt_template is not None:
+                return prompt_template
+        except TemplateLoadError as e:
+            logger.warning(f"TemplateLoadError: {str(e)}")
+            continue
+        except Exception:
+            logger.error(traceback.format_exc())
+            logger.warning(
+                "An unexpected error happened when trying to load the template. "
+                "Trying other methods."
+            )
+            continue
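A hypothetical call site for the new helper, assuming it is awaited during model load with the optional template name from the request (the template name and model path below are made up):

# Sketch only; returns None when every lookup strategy fails.
template = await find_prompt_template("chatml", pathlib.Path("models/MyModel"))
if template is None:
    logger.warning("No prompt template found; chat completions may be unavailable.")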
@@ -85,3 +85,24 @@ def unwrap_optional_type(type_hint) -> Type:
             return arg

     return type_hint
+
+
+def calculate_rope_alpha(base_seq_len: int, target_seq_len: int):
+    """
+    Converts a given max sequence length to a rope alpha value.
+
+    Args:
+        base_seq_len: The model's configured sequence length.
+        target_seq_len: The user-specified max sequence length.
+    """
+
+    # Get the ratio of the model's max sequence length to the target
+    ratio = target_seq_len / base_seq_len
+
+    # Default to a 1 alpha if the sequence length is ever less
+    # than or equal to 1
+    if ratio <= 1.0:
+        alpha = 1
+    else:
+        alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
+    return alpha
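A quick worked example of the formula above (not part of the diff): stretching a model configured for 4096 tokens to 8192 gives ratio = 2.0, so alpha = -0.13436 + 0.80541 * 2.0 + 0.28833 * 4.0 ≈ 2.63.

alpha = calculate_rope_alpha(base_seq_len=4096, target_seq_len=8192)
# ratio = 2.0 -> -0.13436 + 1.61082 + 1.15332 ≈ 2.63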
@@ -122,7 +122,7 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
 async def get_max_length() -> MaxLengthResponse:
     """Fetches the max length of the model."""

-    max_length = model.container.get_model_parameters().get("max_seq_len")
+    max_length = model.container.model_info().parameters.max_seq_len
     return {"value": max_length}

@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     try:
         logger.info(f"Received Kobold generation request {data.genkey}")

-        generator = model.container.generate_gen(
+        generator = model.container.stream_generate(
             request_id=data.genkey, abort_event=abort_event, **data.model_dump()
         )
         async for generation in generator:
@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
     stream_options: Optional[ChatCompletionStreamOptions] = None
-    logprobs: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("logprobs", 0),
-        ge=0,
-    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    prompt,
                     request.state.id,
+                    prompt,
+                    task_gen_params,
                     abort_event,
-                    embeddings=embeddings,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    mm_embeddings=embeddings,
                 )
             )

@@ -422,10 +422,10 @@ async def generate_chat_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt,
                     request.state.id,
-                    embeddings=embeddings,
-                    **data.model_dump(exclude={"prompt"}),
+                    prompt,
+                    data,
+                    mm_embeddings=embeddings,
                 )
             )
         )
@@ -465,7 +465,6 @@ async def generate_tool_calls(
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
     tool_data.json_schema = tool_data.tool_call_schema
-    gen_params = tool_data.model_dump()

     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        pre_tool_prompt,
                         request.state.id,
+                        pre_tool_prompt,
+                        tool_data,
                         embeddings=mm_embeddings,
-                        **gen_params,
                     )
                 )
             )
@@ -8,12 +8,12 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import List, Union
 from loguru import logger
+from typing import List, Optional, Union

 from common import model
 from common.auth import get_key_permission
+from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
 async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
-    prompt: str,
     request_id: str,
+    prompt: str,
+    params: CompletionRequest,
     abort_event: asyncio.Event,
-    **kwargs,
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
 ):
     """Collects a stream and places results in a common queue"""

     try:
-        new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+        new_generation = model.container.stream_generate(
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -115,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
     if (
         model.container
         and model.container.model_dir.name == model_name
-        and model.container.model_loaded
+        and model.container.loaded
     ):
         return

@@ -195,10 +200,10 @@ async def stream_generate_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    data.prompt,
                     request.state.id,
+                    data.prompt,
+                    task_gen_params,
                     abort_event,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )

@@ -256,9 +261,9 @@ async def generate_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    data.prompt,
                     request.state.id,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    data.prompt,
+                    task_gen_params,
                 )
             )
         )
@@ -5,10 +5,8 @@ from typing import Optional
 from common import model
 from common.networking import get_generator_error, handle_request_disconnect
 from common.tabby_config import config
-from common.utils import unwrap
 from endpoints.core.types.model import (
     ModelCard,
-    ModelCardParameters,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
@@ -64,30 +62,7 @@ async def get_current_model_list(model_type: str = "model"):
 def get_current_model():
     """Gets the current model with all parameters."""

-    model_params = model.container.get_model_parameters()
-    draft_model_params = model_params.pop("draft", {})
-
-    if draft_model_params:
-        model_params["draft"] = ModelCard(
-            id=unwrap(draft_model_params.get("name"), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-    else:
-        draft_model_params = None
-
-    model_card = ModelCard(
-        id=unwrap(model_params.pop("name", None), "unknown"),
-        parameters=ModelCardParameters.model_validate(model_params),
-        logging=config.logging,
-    )
-
-    if draft_model_params:
-        draft_card = ModelCard(
-            id=unwrap(draft_model_params.pop("name", None), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-
-        model_card.parameters.draft = draft_card
-
+    model_card = model.container.model_info()
     return model_card

@@ -40,8 +40,6 @@ dependencies = [
    "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
    "winloop ; platform_system == 'Windows'",

-    "numpy < 2.0.0",
-
    # For python 3.12
    "setuptools ; python_version >= '3.12'"
]
@@ -60,55 +58,55 @@ dev = [
 ]
 cu121 = [
     # Torch (Extra index URLs not support in pyproject.toml)
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Exl2
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/kingbri1/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Windows FA2 from https://github.com/kingbri1/flash-attention/releases
-    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
-    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",

-    # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    # Linux FA2 from https://github.com/kingbri1/flash-attention/releases
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 ]
 amd = [
     # Torch triton for ROCm
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.2.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.3.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Torch
-    "torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/rocm6.2.4/torch-2.6.0%2Brocm6.2.4-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/rocm6.3/torch-2.7.0%2Brocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Exl2
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.2.4.torch2.6.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
+    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+rocm6.3.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 ]

 # MARK: Ruff options
@@ -14,9 +14,6 @@ max_tokens:
 min_tokens:
   override: 0
   force: false
-generate_window:
-  override: 512
-  force: false
 stop:
   override: []
   force: false