mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-26 17:28:54 +00:00)
fixup: autosplit, start work on metrics
@@ -1,8 +1,6 @@
 import asyncio
 import gc
-import math
 import pathlib
-from loguru import logger
 from typing import (
     Any,
     AsyncIterator,
@@ -12,9 +10,14 @@ from typing import (
 )
 
 import torch
+from exllamav3 import AsyncGenerator, AsyncJob, Cache, Config, Model, Tokenizer
+from loguru import logger
 
 from backends.base_model_container import BaseModelContainer
 from common.concurrency import iterate_in_threadpool
+from common.gen_logging import (
+    log_metrics,
+)
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
@@ -22,8 +25,6 @@ from common.transformers_utils import GenerationConfig
 from common.utils import unwrap
 from endpoints.core.types.model import ModelCard
 
-from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer
-
 
 class ExllamaV3Container(BaseModelContainer):
     """Abstract base class for model containers."""
@@ -112,7 +113,7 @@ class ExllamaV3Container(BaseModelContainer):
 
         # Reserve VRAM for each GPU
         self.autosplit_reserve = [
-            int(math.ceil(value/1024))
+            value/1024
            for value in autosplit_reserve_megabytes
         ]
         # TODO: speculative decoding
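
The autosplit hunk above drops the import math / math.ceil rounding: per-GPU reserve values configured in megabytes (autosplit_reserve_megabytes) are now passed along as fractional gigabytes instead of being rounded up to whole gigabytes. A minimal sketch of the before/after behaviour, with hypothetical reserve values:

    import math

    # Hypothetical per-GPU reserve values from config, in megabytes
    autosplit_reserve_megabytes = [512, 96]

    # Old behaviour: round each reserve up to a whole gigabyte
    old_reserve = [int(math.ceil(value / 1024)) for value in autosplit_reserve_megabytes]
    print(old_reserve)  # [1, 1]

    # New behaviour: keep fractional gigabytes
    new_reserve = [value / 1024 for value in autosplit_reserve_megabytes]
    print(new_reserve)  # [0.5, 0.09375]
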
@@ -504,15 +505,17 @@ class ExllamaV3Container(BaseModelContainer):
         )
 
         generation = {}
-        print(max_tokens)
         job = AsyncJob(
             self.generator,
             input_ids=self.tokenizer.encode(prompt, add_bos=False),
             max_new_tokens=max_tokens,
             stop_conditions=stop_conditions,
         )
+
         generated_tokens = 0
         full_response = ""
+        metrics_result = {}
+
         async for result in job:
             chunk = unwrap(result.get("text"), "")
             if chunk:
@@ -530,6 +533,25 @@ class ExllamaV3Container(BaseModelContainer):
 
             if result.get("eos"):
                 generation = self.handle_finish_chunk(result, generation)
+
+                # Save the final result for metrics logging
+                metrics_result = result
+
                 yield generation
+                break
         # Assign the active job to the request ID
         self.active_job_ids[request_id] = job
+
+        # Log the metrics if present
+        if metrics_result:
+            log_metrics(
+                request_id,
+                metrics_result.get("time_enqueued"),
+                metrics_result.get("prompt_tokens"),
+                metrics_result.get("cached_tokens"),
+                metrics_result.get("time_prefill"),
+                metrics_result.get("new_tokens"),
+                metrics_result.get("time_generate"),
+                context_len,
+                self.max_seq_len,
+            )
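
The metrics half of this commit saves the final eos result from the exllamav3 AsyncJob and forwards its timing and token fields to log_metrics from common.gen_logging. The body of that helper is not part of this diff; the following is only a rough sketch of a logger consistent with the call site above, using loguru (which the file already imports). The parameter order matches the call, but the formatting and the derived tokens-per-second value are assumptions, not the project's actual implementation:

    from loguru import logger


    def log_metrics(
        request_id,
        time_enqueued,
        prompt_tokens,
        cached_tokens,
        time_prefill,
        new_tokens,
        time_generate,
        context_len,
        max_seq_len,
    ):
        """Hypothetical sketch of a metrics logger matching the call site."""

        # The AsyncJob result fields may be missing on some runs,
        # so guard before deriving throughput
        tokens_per_sec = (
            round(new_tokens / time_generate, 2)
            if new_tokens and time_generate
            else 0.0
        )

        logger.info(
            f"Metrics (ID: {request_id}): generated {new_tokens} tokens "
            f"in {time_generate or 0:.2f} s ({tokens_per_sec} T/s), "
            f"queued {time_enqueued or 0:.2f} s, prefill {time_prefill or 0:.2f} s, "
            f"{cached_tokens}/{prompt_tokens} prompt tokens cached, "
            f"context {context_len}/{max_seq_len}"
        )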