Update prompt speed test

This commit is contained in:
turboderp
2024-04-18 09:19:25 +02:00
parent aef7bd125a
commit dc1dfc4dd5

View File

@@ -467,22 +467,36 @@ if args.prompt_speed:
# Body of the `if args.prompt_speed:` branch: measure prompt processing speed
# at increasing prompt lengths, up to model.config.max_seq_len. Each length is
# timed over several forward passes and the average is reported in tokens/s.
print(f" -- Measuring prompt speed...")
torch.cuda.synchronize()

current_len = 128
step = 128
prompt_iters = 3  # number of timed forward passes averaged per prompt length

while True:

    total_time = 0
    for _ in range(prompt_iters):

        # Wait for outstanding GPU work so it is not billed to this pass.
        torch.cuda.synchronize()
        time_begin = time.perf_counter()

        # NOTE(review): presumably rewinds the cache so every pass processes
        # the prompt from position 0 - confirm against the cache class.
        cache.current_seq_len = 0
        model.forward(ids[:, :current_len], cache, preprocess_only = True)

        # CUDA kernels run asynchronously; synchronize before stopping the
        # timer so the measured interval covers the whole forward pass.
        torch.cuda.synchronize()
        time_end = time.perf_counter()
        total_time += time_end - time_begin

    # Tokens per second based on the mean time of one pass.
    tps = current_len / (total_time / prompt_iters)
    print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")

    # Take coarser steps as the prompt grows so long contexts do not need
    # hundreds of measurements. current_len only increases, so this chain
    # yields the same step as three independent ascending checks.
    if current_len >= 16384:
        step = 8192
    elif current_len >= 4096:
        step = 4096
    elif current_len >= 1024:
        step = 1024

    # Clamp to the model's maximum length; stop once the length can no
    # longer grow (i.e. the maximum has just been measured).
    current_len_ = current_len
    current_len = min(current_len + step, model.config.max_seq_len)
    if current_len == current_len_: break