Model: Add exl3 and associated load functions

Initial exl3 compat and loading functionality.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
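The hunks below only cover the ExllamaV2Container bookkeeping this refactor touches. For orientation, the exl2 load sequence that the container's create() path wraps looks roughly like the sketch below. This is a minimal sketch against the public exllamav2 API, not the container's actual code: the model directory and context length are illustrative, and keyword names may vary between exllamav2 versions.

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Config,
        ExLlamaV2Tokenizer,
    )
    from exllamav2.generator import ExLlamaV2DynamicGeneratorAsync

    # Build and prepare the config (prepare() reads config.json from model_dir)
    config = ExLlamaV2Config()
    config.model_dir = "models/my-exl2-model"  # illustrative path
    config.prepare()
    config.max_seq_len = 4096                  # override the model's default
    config.arch_compat_overrides()             # drop features the arch lacks

    # Load the model with a lazily allocated cache, then build the tokenizer
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)

    # Async dynamic generator, as typed on the container in the first hunk below
    generator = ExLlamaV2DynamicGeneratorAsync(
        model=model, cache=cache, tokenizer=tokenizer
    )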
@@ -64,16 +64,19 @@ class ExllamaV2Container(BaseModelContainer):
 
+    # Exl2 vars
     config: Optional[ExLlamaV2Config] = None
-    draft_config: Optional[ExLlamaV2Config] = None
     model: Optional[ExLlamaV2] = None
-    draft_model: Optional[ExLlamaV2] = None
     cache: Optional[ExLlamaV2Cache] = None
-    draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     paged: bool = True
 
+    # Draft model vars
+    use_draft_model: bool = False
+    draft_config: Optional[ExLlamaV2Config] = None
+    draft_model: Optional[ExLlamaV2] = None
+    draft_cache: Optional[ExLlamaV2Cache] = None
 
     # Internal config vars
     cache_size: int = None
     cache_mode: str = "FP16"
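A side note on the new class vars: use_draft_model is annotated bool, but the assignment introduced later in this commit (self.use_draft_model = draft_args and draft_model_name) evaluates to one of the operands rather than a strict boolean. A small standalone illustration, not taken from the commit:

    # Python's `and` returns an operand, not a coerced bool
    draft_args = {"draft_model_name": "tiny-draft"}  # hypothetical config values
    draft_model_name = draft_args.get("draft_model_name")

    use_draft_model = draft_args and draft_model_name
    print(type(use_draft_model))   # <class 'str'> -- truthy, but not a bool
    print(bool(use_draft_model))   # True -- explicit coercion yields a real bool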
@@ -100,7 +103,7 @@ class ExllamaV2Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()
 
     @classmethod
-    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, **kwargs):
         """
         Primary asynchronous initializer for model container.
 
@@ -110,8 +113,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Create a new instance as a "fake self"
         self = cls()
 
-        self.quiet = quiet
-
         # Initialize config
         self.config = ExLlamaV2Config()
         self.model_dir = model_directory
@@ -122,6 +123,7 @@ class ExllamaV2Container(BaseModelContainer):
         self.config.max_seq_len = 4096
 
         self.config.prepare()
+        print(self.config.max_seq_len)
 
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
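The print() added after prepare() reads like a temporary check of the effective context length. If the value is worth keeping visible, a hedged alternative would route it through the loguru logger tabbyAPI uses elsewhere; this assumes loguru is available, and the call site is simplified:

    from loguru import logger

    max_seq_len = 4096  # stand-in for self.config.max_seq_len
    logger.debug(f"max_seq_len after prepare(): {max_seq_len}")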
@@ -162,7 +164,7 @@ class ExllamaV2Container(BaseModelContainer):
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
-        enable_draft = draft_args and draft_model_name
+        self.use_draft_model = draft_args and draft_model_name
 
         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
@@ -170,9 +172,9 @@ class ExllamaV2Container(BaseModelContainer):
                 "Draft model is disabled because a model name "
                 "wasn't provided. Please check your config.yml!"
             )
-            enable_draft = False
+            self.use_draft_model = False
 
-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config = ExLlamaV2Config()
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
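For readers outside the codebase: unwrap is tabbyAPI's small null-coalescing helper, so the path construction above falls back to a models directory when draft_model_dir is not configured. A self-contained sketch with the helper re-implemented for illustration (the final join with the model name is assumed, since the hunk ends mid-expression):

    import pathlib

    def unwrap(value, default):
        # Mirrors the helper's behavior: default only when value is None
        return value if value is not None else default

    draft_args = {"draft_model_name": "tiny-draft"}  # hypothetical kwargs
    draft_model_path = pathlib.Path(
        unwrap(draft_args.get("draft_model_dir"), "models")
    ) / draft_args["draft_model_name"]
    print(draft_model_path)  # models/tiny-draft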
@@ -365,7 +367,7 @@ class ExllamaV2Container(BaseModelContainer):
             self.config.max_attention_size = chunk_size**2
 
         # Set user-configured draft model values
-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             self.draft_config.scale_pos_emb = unwrap(