diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index 7e8402f..f3fb26e 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -2,6 +2,7 @@ import asyncio
 import gc
 import pathlib
 import re
+import torch
 from itertools import zip_longest
 from typing import (
     Any,
@@ -11,7 +12,6 @@ from typing import (
     Optional,
 )
 
-import torch
 from exllamav3 import (
     AsyncGenerator,
     AsyncJob,
@@ -229,7 +229,9 @@ class ExllamaV3Container(BaseModelContainer):
         self.cache = self.create_cache(self.cache_mode, self.model)
 
         # Limit max_seq_len to prevent sequences larger than the cache
-        max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings)
+        max_seq_len = unwrap(
+            kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings
+        )
         self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
 
         # Draft cache
@@ -967,7 +969,7 @@ class ExllamaV3Container(BaseModelContainer):
             banned_strings=params.banned_strings,
             embeddings=mm_embeddings_content,
             return_top_tokens=params.logprobs,
-            max_rq_tokens=self.max_rq_tokens
+            max_rq_tokens=self.max_rq_tokens,
         )
 
         generated_tokens = 0
diff --git a/backends/exllamav3/utils.py b/backends/exllamav3/utils.py
index 0a90487..5a3e68d 100644
--- a/backends/exllamav3/utils.py
+++ b/backends/exllamav3/utils.py
@@ -1,6 +1,8 @@
 import platform
 
+import torch
 from loguru import logger
+
 def exllama_supports_nccl():
     if platform.system() == "Windows":
         unsupported_message = (
@@ -9,5 +11,4 @@ def exllama_supports_nccl():
         logger.warning(unsupported_message)
         return False
 
-    import torch
     return torch.cuda.is_available() and torch.distributed.is_nccl_available()
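
Note: with `import torch` hoisted to module scope in `utils.py`, torch is now imported even on platforms where `exllama_supports_nccl()` short-circuits to `False`. For context, a minimal sketch of how a caller might use this check to pick a `torch.distributed` backend; the `init_distributed` helper and the gloo fallback are illustrative assumptions, not part of this change:

```python
# Hypothetical caller (not in this diff): use the NCCL support check
# to choose a torch.distributed backend before multi-GPU setup.
import torch.distributed as dist

from backends.exllamav3.utils import exllama_supports_nccl


def init_distributed(rank: int, world_size: int) -> None:
    # exllama_supports_nccl() warns and returns False on Windows, and
    # also when CUDA or NCCL is unavailable; fall back to gloo there.
    backend = "nccl" if exllama_supports_nccl() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
```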