Exl3: Couldn't wait

Just copied some stuff around and it ended up working for basic use.
randoentity
2025-04-29 23:57:53 +02:00
committed by kingbri
parent b4ff2f23cf
commit daae9ec43d
2 changed files with 221 additions and 10 deletions

View File

@@ -533,8 +533,7 @@ class ExllamaV2Container(BaseModelContainer):
         # Load draft model if a config is present
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
-            if not self.quiet:
-                logger.info("Loading draft model: " + self.draft_config.model_dir)
+            logger.info("Loading draft model: " + self.draft_config.model_dir)

             # Draft uses the autosplit loader, so create a cache that reflects this
             draft_cache_class = self.get_cache_class(self.draft_cache_mode)
@@ -587,8 +586,7 @@ class ExllamaV2Container(BaseModelContainer):
                     yield value

         self.model = ExLlamaV2(self.config)
-        if not self.quiet:
-            logger.info("Loading model: " + self.config.model_dir)
+        logger.info("Loading model: " + self.config.model_dir)

         # Get class of the model cache
         cache_class = self.get_cache_class(self.cache_mode)

View File

@@ -16,12 +16,12 @@ from backends.base_model_container import BaseModelContainer
 from common.concurrency import iterate_in_threadpool
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
-from common.templating import PromptTemplate
+from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
 from common.utils import unwrap
 from endpoints.core.types.model import ModelCard

-from exllamav3 import Config, Model, Cache, Tokenizer
+from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer


 class ExllamaV3Container(BaseModelContainer):
@@ -46,6 +46,8 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    gpu_split: List[float] = []
+    max_seq_len: int = 2048

     # Required methods
     @classmethod
@@ -74,6 +76,16 @@ class ExllamaV3Container(BaseModelContainer):
         max_seq_len = kwargs.get("max_seq_len")
         self.cache = Cache(self.model, max_num_tokens=max_seq_len)

+        gpu_split = unwrap(kwargs.get("gpu_split"), [])
+
+        # Set GPU split options
+        # Enable manual GPU split if provided
+        if gpu_split:
+            self.gpu_split = gpu_split
+
+        # Try to set prompt template
+        self.prompt_template = await find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )

         return self
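For orientation, a minimal sketch of how these new options might be supplied when the container is built. It assumes the async classmethod above is the `create` entry point used by the other backends; the model path and values below are invented for illustration:

    # Hypothetical usage sketch; `create`, the model path, and the option values
    # are assumptions, not part of this commit.
    container = await ExllamaV3Container.create(
        model_directory="/models/my-exl3-model",  # hypothetical path
        max_seq_len=8192,
        gpu_split=[20.0, 24.0],    # stored on the container, later fed to load_gen(use_per_device=...)
        prompt_template="chatml",  # resolved through find_prompt_template()
    )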
@@ -128,7 +140,10 @@ class ExllamaV3Container(BaseModelContainer):
     # TODO: Add draft loading
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
-        for value in self.model.load_gen(callback=progress_callback):
+        for value in self.model.load_gen(
+            use_per_device=self.gpu_split,
+            callback=progress_callback
+        ):
             if value:
                 yield value
@@ -263,7 +278,58 @@ class ExllamaV3Container(BaseModelContainer):
             A dictionary containing the generation info
         """

-        pass
+        generations = []
+        async for generation in self.stream_generate(
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
+        ):
+            generations.append(generation)
+
+        joined_generation = {
+            "text": "",
+            "prompt_tokens": 0,
+            "generation_tokens": 0,
+            "tool_calls": None,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            # Get finish_reason first and then shift where -1 points to
+            if "finish_reason" in generations[-1]:
+                finish_reason_gen = generations.pop()
+                joined_generation["finish_reason"] = finish_reason_gen.get(
+                    "finish_reason"
+                )
+                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+            else:
+                joined_generation["finish_reason"] = "stop"
+
+        if len(generations) > 0:
+            for generation in generations:
+                joined_generation["text"] += unwrap(generation.get("text"), "")
+                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
+                joined_generation["token_probs"].update(
+                    unwrap(generation.get("token_probs"), {})
+                )
+
+                # Include empty logprob dicts for index preservation
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
+
+            joined_generation["prompt_tokens"] = unwrap(
+                generations[-1].get("prompt_tokens"), 0
+            )
+            joined_generation["generated_tokens"] = unwrap(
+                generations[-1].get("generated_tokens"), 0
+            )
+
+        return joined_generation

     async def stream_generate(
         self,
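The non-streaming generate() above simply folds the streamed chunks into one dict. As a rough illustration with invented values, a finished request comes back shaped like this; note that the initializer spells the key "generation_tokens" while the tail assignment writes "generated_tokens", so both keys end up in the joined result:

    # Illustrative shape only; the values are made up.
    {
        "text": "Hello there!",
        "prompt_tokens": 12,
        "generation_tokens": 0,   # leftover zero from the initializer
        "generated_tokens": 3,    # copied from the last streamed chunk
        "tool_calls": None,
        "offset": [6, 12],        # cumulative text lengths, one per chunk
        "token_probs": {},
        "logprobs": [{}, {}],
        "finish_reason": "stop",
        "stop_str": None,
    }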
@@ -287,5 +353,152 @@ class ExllamaV3Container(BaseModelContainer):
             Generation chunks
         """

-        if False:
-            yield
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
+    def handle_finish_chunk(self, result: dict, generation: dict):
+        eos_reason = result.get("eos_reason")
+
+        stop_str = None
+        if eos_reason == "max_new_tokens":
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+            # Grab stop string if stop was the reason
+            if eos_reason == "stop_token":
+                stop_str = result.get("eos_triggering_token_str")
+            elif eos_reason == "stop_string":
+                stop_str = result.get("eos_triggering_string")
+
+        finish_chunk = {
+            "prompt_tokens": generation.get("prompt_tokens"),
+            "generated_tokens": generation.get("generated_tokens"),
+            "finish_reason": finish_reason,
+            "stop_str": stop_str,
+        }
+
+        return finish_chunk
+
+    async def generate_gen(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        """
+        Create generator function for prompt completion.
+
+        for kwargs, check common/sampling.py
+        """
+
+        chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+
+        prompts = [prompt]
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+
+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
+        stop_conditions += eos_tokens
+
+        input_ids = [
+            self.tokenizer.encode(
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+            )
+            for prompt in prompts
+        ]
+
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)
+
+        # Automatically set max_tokens to fill up the context
+        # This should be an OK default, but may be changed in the future
+        max_tokens = unwrap(
+            params.max_tokens,
+            self.max_seq_len - context_len,
+        )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
+            max_tokens = 1
+
+        # Determine if the negative context or the context length is bigger
+        context_to_check = context_len
+
+        # Check total length of prompt against max context length
+        if context_to_check > self.max_seq_len:
+            preamble = "Prompt"
+
+            raise ValueError(
+                f"{preamble} length {context_to_check} is greater than "
+                f"max_seq_len {self.max_seq_len}"
+            )
+
+        self.generator = AsyncGenerator(
+            model=self.model,
+            cache=self.cache,
+            tokenizer=self.tokenizer,
+        )
+
+        generation = {}
+
+        print(max_tokens)
+        job = AsyncJob(
+            self.generator,
+            input_ids=self.tokenizer.encode(prompt, add_bos=False),
+            max_new_tokens=max_tokens,
+            stop_conditions=stop_conditions,
+        )
+
+        generated_tokens = 0
+        full_response = ""
+        async for result in job:
+            chunk = unwrap(result.get("text"), "")
+            if chunk:
+                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                full_response += chunk
+
+                if isinstance(chunk_tokens, torch.Tensor):
+                    generated_tokens += chunk_tokens.size(dim=0)
+
+                generation = {
+                    "text": chunk,
+                    "prompt_tokens": context_len,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+                yield generation
+
+            if result.get("eos"):
+                generation = self.handle_finish_chunk(result, generation)
+                yield generation
+
+        # Assign the active job to the request ID
+        self.active_job_ids[request_id] = job
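Taken together, a caller would drive the new backend roughly as follows. This is a hedged sketch, not code from the commit: the request ids and prompt are invented, a prebuilt BaseSamplerRequest is assumed, and abort_event / mm_embeddings are left at their defaults:

    # Hypothetical caller; the container is assumed to already be created and
    # loaded through the methods shown in the diff above.
    async def run_request(container, params):
        prompt = "Write a haiku about llamas."  # invented prompt

        # Streaming path: chunks carry "text", "prompt_tokens",
        # "generated_tokens" and "offset", followed by a finish chunk.
        async for chunk in container.stream_generate("req-1", prompt, params):
            print(chunk.get("text", ""), end="", flush=True)

        # Non-streaming path: generate() folds the same chunks into one dict.
        result = await container.generate("req-2", prompt, params)
        print(result["finish_reason"])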