mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Update prompt speed test
This commit is contained in:
@@ -467,22 +467,36 @@ if args.prompt_speed:
|
||||
|
||||
print(f" -- Measuring prompt speed...")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
current_len = 128
|
||||
step = 128
|
||||
prompt_iters = 3
|
||||
while True:
|
||||
|
||||
time_begin = time.time()
|
||||
total_time = 0
|
||||
for i in range(prompt_iters):
|
||||
|
||||
cache.current_seq_len = 0
|
||||
model.forward(ids[:, :current_len], cache, preprocess_only = True)
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
time_begin = time.time()
|
||||
|
||||
time_end = time.time()
|
||||
tps = current_len / (time_end - time_begin)
|
||||
cache.current_seq_len = 0
|
||||
model.forward(ids[:, :current_len], cache, preprocess_only = True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.time()
|
||||
total_time += time_end - time_begin
|
||||
|
||||
tps = current_len / (total_time / prompt_iters)
|
||||
|
||||
print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
|
||||
|
||||
if current_len >= 1024: step = 1024
|
||||
if current_len >= 4096: step = 4096
|
||||
if current_len >= 16384: step = 8192
|
||||
|
||||
current_len_ = current_len
|
||||
current_len = min(current_len + 128, model.config.max_seq_len)
|
||||
current_len = min(current_len + step, model.config.max_seq_len)
|
||||
if current_len == current_len_: break
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user