Mirror of https://github.com/turboderp-org/exllamav2.git, synced 2026-04-27 09:44:13 +00:00
TP mode for attn layer, non-paged
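The change applies one recurring pattern: wherever the script builds a KV cache, it now constructs an ExLlamaV2Cache_TP when the model was loaded in tensor-parallel mode (model.tp_context is set) and the regular cache class otherwise. A minimal sketch of that pattern, using only the constructor forms that appear in the hunks below; the helper name make_cache is hypothetical and not part of the exllamav2 API:

from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_TP,
)

def make_cache(model, max_seq_len = None, base = None):
    # Collect only the arguments the script actually passes in the hunks below
    kwargs = {} if max_seq_len is None else {"max_seq_len": max_seq_len}
    if model.tp_context:
        # Tensor-parallel path (assumption: ExLlamaV2Cache_TP handles the
        # per-device split internally); a quantized cache class can be
        # supplied via `base`, as in the token-ppl hunks below
        if base is not None:
            kwargs["base"] = base
        return ExLlamaV2Cache_TP(model, **kwargs)
    # Regular single-context path: construct the requested cache class directly
    return (base or ExLlamaV2Cache)(model, **kwargs)

For example, cache = make_cache(model) gives a plain FP16 cache in either mode, and cache = make_cache(model, max_seq_len = 2048, base = ExLlamaV2Cache_Q4) gives a Q4 cache (wrapped by ExLlamaV2Cache_TP when TP is active).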
@@ -7,6 +7,7 @@ from exllamav2 import(
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     model_init,
 )
@@ -96,11 +97,14 @@ if args.stream_layers:

 model_init.check_args(args)
 model_init.print_options(args)
-model, tokenizer = model_init.init(args,
-                                   allow_auto_split = True,
-                                   skip_load = args.stream_layers,
-                                   benchmark = True,
-                                   max_output_len = args.max_output_len)
+model, tokenizer = model_init.init(
+    args,
+    allow_auto_split = True,
+    skip_load = args.stream_layers,
+    benchmark = True,
+    max_output_len = args.max_output_len,
+    progress = True
+)
 cache = None

 # Auto split
@@ -113,7 +117,7 @@ if not model.loaded and not args.stream_layers:
     print(" -- Loading model...")
     cache = ExLlamaV2Cache(model, lazy = True)
     t = time.time()
-    model.load_autosplit(cache)
+    model.load_autosplit(cache, progress = True)
     t = time.time() - t
     print(f" -- Loaded model in {t:.4f} seconds")

@@ -185,7 +189,7 @@ if args.prompt:
     with torch.inference_mode():

         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)

         ids = tokenizer.encode(args.prompt)
         tokens_prompt = ids.shape[-1]
@@ -292,7 +296,8 @@ if args.eval_dataset or args.standard_perplexity:

     def ppl(input_ids__, logits__, lengths__, bins = False):

-        logits_device = model.modules[-1].device()
+        logits_device = model.modules[-1].device() if not model.tp_context else \
+            torch.device(model.tp_context.device)

         if bins:
             num_bins = (max(lengths__) + 255) // 256
@@ -389,7 +394,10 @@ if args.eval_dataset or args.standard_perplexity:
         sys.stdout.flush()

         if cache is None:
-            cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
+            if eval_length > model.config.max_input_len:
+                cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
+            else:
+                cache = None

         for i in range(eval_tokens.shape[0]):

@@ -470,7 +478,8 @@ if args.eval_dataset or args.standard_perplexity:
         else:
             print(f" -- Inference (token)", end = "")
             sys.stdout.flush()
-            cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
+            cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else \
+                ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
             test_ppl_token()

     if args.eval_token_8bit:
@@ -479,7 +488,8 @@ if args.eval_dataset or args.standard_perplexity:
         else:
             print(f" -- Inference (token, 8-bit cache)", end = "")
             sys.stdout.flush()
-            cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
+            cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length) if not model.tp_context else \
+                ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_8bit)
             test_ppl_token()

     if args.eval_token_q4:
@@ -488,7 +498,8 @@ if args.eval_dataset or args.standard_perplexity:
         else:
             print(f" -- Inference (token, Q4 cache)", end = "")
             sys.stdout.flush()
-            cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
+            cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length) if not model.tp_context else \
+                ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q4)
             # cache.calibrate(tokenizer)
             test_ppl_token()

@@ -498,7 +509,8 @@ if args.eval_dataset or args.standard_perplexity:
         else:
             print(f" -- Inference (token, Q6 cache)", end = "")
             sys.stdout.flush()
-            cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length)
+            cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length) if not model.tp_context else \
+                ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q6)
             # cache.calibrate(tokenizer)
             test_ppl_token()

@@ -508,7 +520,8 @@ if args.eval_dataset or args.standard_perplexity:
         else:
             print(f" -- Inference (token, Q8 cache)", end = "")
             sys.stdout.flush()
-            cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length)
+            cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length) if not model.tp_context else \
+                ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q8)
             # cache.calibrate(tokenizer)
             test_ppl_token()

@@ -520,7 +533,7 @@ if args.prompt_speed:
     with torch.inference_mode():

         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)

         ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))

@@ -571,7 +584,7 @@ if args.speed:
     with torch.inference_mode():

         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
         cache.current_seq_len = 0

         print(f" -- Measuring token speed...")