Merge remote-tracking branch 'origin/main_seq' into main_seq

This commit is contained in:
turboderp
2025-10-14 00:58:42 +02:00
2 changed files with 7 additions and 4 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import gc
import pathlib
import re
import torch
from itertools import zip_longest
from typing import (
Any,
@@ -11,7 +12,6 @@ from typing import (
Optional,
)
import torch
from exllamav3 import (
AsyncGenerator,
AsyncJob,
@@ -229,7 +229,9 @@ class ExllamaV3Container(BaseModelContainer):
self.cache = self.create_cache(self.cache_mode, self.model)
# Limit max_seq_len to prevent sequences larger than the cache
max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings)
max_seq_len = unwrap(
kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings
)
self.max_seq_len = self.adjust_max_seq_len(max_seq_len)
# Draft cache
@@ -967,7 +969,7 @@ class ExllamaV3Container(BaseModelContainer):
banned_strings=params.banned_strings,
embeddings=mm_embeddings_content,
return_top_tokens=params.logprobs,
max_rq_tokens=self.max_rq_tokens
max_rq_tokens=self.max_rq_tokens,
)
generated_tokens = 0

View File

@@ -1,6 +1,8 @@
import platform
import torch
from loguru import logger
def exllama_supports_nccl():
if platform.system() == "Windows":
unsupported_message = (
@@ -9,5 +11,4 @@ def exllama_supports_nccl():
logger.warning(unsupported_message)
return False
import torch
return torch.cuda.is_available() and torch.distributed.is_nccl_available()