TP mode for attn layer, non-paged

turboderp
2024-08-14 23:41:10 +02:00
parent 65b9e17c4f
commit b30f796690
6 changed files with 193 additions and 27 deletions


@@ -7,6 +7,7 @@ from exllamav2 import(
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     model_init,
 )
@@ -96,11 +97,14 @@ if args.stream_layers:
 model_init.check_args(args)
 model_init.print_options(args)
-model, tokenizer = model_init.init(args,
-                                   allow_auto_split = True,
-                                   skip_load = args.stream_layers,
-                                   benchmark = True,
-                                   max_output_len = args.max_output_len)
+model, tokenizer = model_init.init(
+    args,
+    allow_auto_split = True,
+    skip_load = args.stream_layers,
+    benchmark = True,
+    max_output_len = args.max_output_len,
+    progress = True
+)
 
 cache = None
 
 # Auto split
@@ -113,7 +117,7 @@ if not model.loaded and not args.stream_layers:
     print(" -- Loading model...")
     cache = ExLlamaV2Cache(model, lazy = True)
     t = time.time()
-    model.load_autosplit(cache)
+    model.load_autosplit(cache, progress = True)
     t = time.time() - t
     print(f" -- Loaded model in {t:.4f} seconds")
 
@@ -185,7 +189,7 @@ if args.prompt:
     with torch.inference_mode():
 
         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
 
         ids = tokenizer.encode(args.prompt)
         tokens_prompt = ids.shape[-1]
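
This `if not model.tp_context else` ternary is the recurring pattern of the commit: a model loaded in tensor-parallel mode carries a `tp_context`, and the sharded `ExLlamaV2Cache_TP` must then be used in place of the regular cache. The same guard reappears in the prompt-speed and token-speed hunks below, so it could be factored into one helper. A minimal sketch, assuming only the names visible in this diff (`make_cache` itself is hypothetical, not part of the commit):

    from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_TP

    def make_cache(model, **kwargs):
        # Hypothetical helper, not in this commit. In TP mode the cache
        # must be the TP wrapper so each device holds its own shard;
        # otherwise a plain single-device cache is constructed.
        if model.tp_context:
            return ExLlamaV2Cache_TP(model, **kwargs)
        return ExLlamaV2Cache(model, **kwargs)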
@@ -292,7 +296,8 @@ if args.eval_dataset or args.standard_perplexity:
 
     def ppl(input_ids__, logits__, lengths__, bins = False):
 
-        logits_device = model.modules[-1].device()
+        logits_device = model.modules[-1].device() if not model.tp_context else \
+            torch.device(model.tp_context.device)
 
         if bins:
             num_bins = (max(lengths__) + 255) // 256
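
The device lookup changes for the same reason: under tensor parallelism there is no single last module whose device holds the logits, so the script reads the output device from the TP context instead. Restated as a standalone function, using only names taken from the hunk above:

    import torch

    def logits_device(model) -> torch.device:
        if model.tp_context:
            # TP mode: the output device is recorded on the TP context.
            return torch.device(model.tp_context.device)
        # Single-device / autosplit mode: logits land on the last module.
        return model.modules[-1].device()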
@@ -389,7 +394,10 @@ if args.eval_dataset or args.standard_perplexity:
     sys.stdout.flush()
 
     if cache is None:
-        cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
+        if eval_length > model.config.max_input_len:
+            cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
+        else:
+            cache = None
 
     for i in range(eval_tokens.shape[0]):
 
@@ -470,7 +478,8 @@ if args.eval_dataset or args.standard_perplexity:
     else:
         print(f" -- Inference (token)", end = "")
     sys.stdout.flush()
-    cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
+    cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if not model.tp_context else \
+        ExLlamaV2Cache_TP(model, max_seq_len = eval_length)
     test_ppl_token()
 
 if args.eval_token_8bit:
@@ -479,7 +488,8 @@ if args.eval_dataset or args.standard_perplexity:
     else:
         print(f" -- Inference (token, 8-bit cache)", end = "")
     sys.stdout.flush()
-    cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
+    cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length) if not model.tp_context else \
+        ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_8bit)
     test_ppl_token()
 
 if args.eval_token_q4:
@@ -488,7 +498,8 @@ if args.eval_dataset or args.standard_perplexity:
     else:
         print(f" -- Inference (token, Q4 cache)", end = "")
     sys.stdout.flush()
-    cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
+    cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length) if not model.tp_context else \
+        ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q4)
     # cache.calibrate(tokenizer)
     test_ppl_token()
 
@@ -498,7 +509,8 @@ if args.eval_dataset or args.standard_perplexity:
     else:
         print(f" -- Inference (token, Q6 cache)", end = "")
     sys.stdout.flush()
-    cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length)
+    cache = ExLlamaV2Cache_Q6(model, max_seq_len = eval_length) if not model.tp_context else \
+        ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q6)
     # cache.calibrate(tokenizer)
     test_ppl_token()
 
@@ -508,7 +520,8 @@ if args.eval_dataset or args.standard_perplexity:
     else:
         print(f" -- Inference (token, Q8 cache)", end = "")
     sys.stdout.flush()
-    cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length)
+    cache = ExLlamaV2Cache_Q8(model, max_seq_len = eval_length) if not model.tp_context else \
+        ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = ExLlamaV2Cache_Q8)
     # cache.calibrate(tokenizer)
     test_ppl_token()
 
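
The four quantized-cache hunks above all use one mechanism: the quantization format is not a separate TP cache class but the `base` argument to `ExLlamaV2Cache_TP`, which shards a cache of that type across devices. A sketch of how the per-flag constructions could collapse into a table-driven call; the `CACHE_BASES` mapping and `eval_cache` helper are illustrative, only `base` and `max_seq_len` come from the diff:

    from exllamav2 import (
        ExLlamaV2Cache,
        ExLlamaV2Cache_8bit,
        ExLlamaV2Cache_Q4,
        ExLlamaV2Cache_Q6,
        ExLlamaV2Cache_Q8,
        ExLlamaV2Cache_TP,
    )

    # Hypothetical mapping mirroring the script's --eval_token_* flags.
    CACHE_BASES = {
        "fp16": ExLlamaV2Cache,
        "8bit": ExLlamaV2Cache_8bit,
        "q4": ExLlamaV2Cache_Q4,
        "q6": ExLlamaV2Cache_Q6,
        "q8": ExLlamaV2Cache_Q8,
    }

    def eval_cache(model, kind, eval_length):
        base = CACHE_BASES[kind]
        if model.tp_context:
            # One TP wrapper for every format; `base` selects the k/v layout.
            return ExLlamaV2Cache_TP(model, max_seq_len = eval_length, base = base)
        return base(model, max_seq_len = eval_length)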
@@ -520,7 +533,7 @@ if args.prompt_speed:
     with torch.inference_mode():
 
         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
 
         ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
 
@@ -571,7 +584,7 @@ if args.speed:
     with torch.inference_mode():
 
         if cache is None:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model) if not model.tp_context else ExLlamaV2Cache_TP(model)
         cache.current_seq_len = 0
 
         print(f" -- Measuring token speed...")