Add some debug functions

This commit is contained in:
turboderp
2024-10-02 23:58:33 +02:00
parent b651f4abab
commit ed6dc9b7b3
4 changed files with 28 additions and 3 deletions

View File

@@ -6,7 +6,7 @@ from exllamav2.generator.filters import ExLlamaV2Filter
from exllamav2.cache import ExLlamaV2CacheBase, ExLlamaV2Cache_8bit
from exllamav2.attn import ExLlamaV2Attention, assert_paged_attn
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2.util import cuda_sync_active
from exllamav2.util import cuda_sync_active, timed
from concurrent.futures import ThreadPoolExecutor
from exllamav2.compat import pairwise
@@ -525,7 +525,7 @@ class ExLlamaV2DynamicGenerator:
self.current_loras = loras
else:
self.current_loras = [loras]
def generate(
self,
@@ -667,6 +667,7 @@ class ExLlamaV2DynamicGenerator:
while self.num_remaining_jobs():
results = self.iterate()
for r in results:
idx = order[r["serial"]]
if r["stage"] == "streaming":

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, Future
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.util import timed
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import torch

View File

@@ -53,7 +53,7 @@ import gc
import threading
from typing import Callable
# from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak
from exllamav2.util import get_basic_progress
from exllamav2.util import get_basic_progress, timed
# from line_profiler import profile
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
import gc, subprocess, time, os, json
from collections import deque
import torch
@@ -14,6 +15,28 @@ class Timer:
self.interval = self.end_time - self.start_time
# Rolling per-function timing stats, keyed by function __name__.
# Each deque holds at most the `num_run_avg` most recent call durations.
timings: dict[str, deque[float]] = {}
num_run_avg = 10

def timed(func):
    """Decorator: print each call's wall time and a running average.

    Stores the last `num_run_avg` durations per function in the
    module-level `timings` dict, prints the elapsed time and the running
    average after every call, and returns the wrapped function's result
    unchanged.
    """
    # Local import keeps this debug helper self-contained.
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ of the decorated function
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump backwards/forwards with wall-clock adjustments.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start

        fname = func.__name__
        if fname not in timings:
            # maxlen makes the deque drop the oldest sample automatically,
            # replacing the manual popleft bookkeeping.
            timings[fname] = deque(maxlen = num_run_avg)
        timings[fname].append(elapsed)
        avg = sum(timings[fname]) / len(timings[fname])
        print(f"{fname} executed in {elapsed:.4f} seconds, running avg. ({num_run_avg}): {avg:.4f}")
        return result

    return wrapper
class SeqTensor:
PAGE_SIZE = 256