mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Add some debug functions
This commit is contained in:
@@ -6,7 +6,7 @@ from exllamav2.generator.filters import ExLlamaV2Filter
|
||||
from exllamav2.cache import ExLlamaV2CacheBase, ExLlamaV2Cache_8bit
|
||||
from exllamav2.attn import ExLlamaV2Attention, assert_paged_attn
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
from exllamav2.util import cuda_sync_active
|
||||
from exllamav2.util import cuda_sync_active, timed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from exllamav2.compat import pairwise
|
||||
@@ -525,7 +525,7 @@ class ExLlamaV2DynamicGenerator:
|
||||
self.current_loras = loras
|
||||
else:
|
||||
self.current_loras = [loras]
|
||||
|
||||
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -667,6 +667,7 @@ class ExLlamaV2DynamicGenerator:
|
||||
|
||||
while self.num_remaining_jobs():
|
||||
results = self.iterate()
|
||||
|
||||
for r in results:
|
||||
idx = order[r["serial"]]
|
||||
if r["stage"] == "streaming":
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
from threading import Lock
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.util import timed
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
import torch
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ import gc
|
||||
import threading
|
||||
from typing import Callable
|
||||
# from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak
|
||||
from exllamav2.util import get_basic_progress
|
||||
from exllamav2.util import get_basic_progress, timed
|
||||
# from line_profiler import profile
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||
import gc, subprocess, time, os, json
|
||||
from collections import deque
|
||||
import torch
|
||||
|
||||
|
||||
@@ -14,6 +15,28 @@ class Timer:
|
||||
self.interval = self.end_time - self.start_time
|
||||
|
||||
|
||||
# Rolling per-function timing stats, filled in by the @timed decorator below.
# Maps function name -> deque of the most recent call durations in seconds.
# NOTE: annotation fixed — `dict[str: ...]` was slice syntax, not a type.
timings: dict[str, deque[float]] = {}
# How many of the most recent calls are kept per function for the running average.
num_run_avg = 10


def timed(func):
    """Debug decorator: time each call to *func* and print the elapsed time
    together with a running average over the last `num_run_avg` calls.

    The wrapped function's return value is passed through unchanged.
    """
    from functools import wraps

    @wraps(func)  # preserve func.__name__ / __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time

        fname = func.__name__
        if fname not in timings:
            # maxlen makes the deque drop the oldest sample automatically,
            # replacing the manual popleft bookkeeping
            timings[fname] = deque(maxlen=num_run_avg)
        timings[fname].append(elapsed_time)
        avg = sum(timings[fname]) / len(timings[fname])
        print(f"{fname} executed in {elapsed_time:.4f} seconds, running avg. ({num_run_avg}): {avg:.4f}")
        return result
    return wrapper
|
||||
|
||||
|
||||
class SeqTensor:
|
||||
|
||||
PAGE_SIZE = 256
|
||||
|
||||
Reference in New Issue
Block a user