mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
Instrumentation etc.
This commit is contained in:
@@ -6,6 +6,12 @@ from exllamav2.linear import ExLlamaV2Linear
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
from exllamav2 import ext
|
||||
|
||||
# catch_key = None
|
||||
# def set_catch(key):
|
||||
# global catch_key
|
||||
# catch_key = key
|
||||
|
||||
|
||||
class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
|
||||
layer_idx: int
|
||||
@@ -130,6 +136,10 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.down_proj.set_device_idx(idx)
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
# global catch_key
|
||||
#
|
||||
# if self.key == catch_key:
|
||||
# return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
|
||||
if self.q_handle is None or intermediates:
|
||||
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
|
||||
@@ -20,6 +20,8 @@ import torch.nn.functional as F
|
||||
from conversion.tokenize import get_tokens
|
||||
from conversion.quantize import list_live_tensors
|
||||
|
||||
# from exllamav2.mlp import set_catch
|
||||
|
||||
import sys
|
||||
import json
|
||||
|
||||
@@ -188,6 +190,8 @@ if args.eval_dataset:
|
||||
global logits, target_ids, log_probs, token_log_probs
|
||||
global mean_log_prob, perplexity
|
||||
|
||||
# set_catch("model.layers.3")
|
||||
|
||||
logprob_sum = 0
|
||||
logprob_count = 0
|
||||
|
||||
@@ -207,6 +211,10 @@ if args.eval_dataset:
|
||||
logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
|
||||
logprob_count += 1
|
||||
|
||||
# mean_log_prob = logprob_sum / logprob_count
|
||||
# perplexity = math.exp(-mean_log_prob)
|
||||
# print(f" -- Token {j}: {perplexity:.4f}")
|
||||
|
||||
print()
|
||||
|
||||
mean_log_prob = logprob_sum / logprob_count
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Linear
|
||||
from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Linear
|
||||
from exllamav2.tokenizer import ExLlamaV2Tokenizer
|
||||
import argparse, os, math, time
|
||||
import pandas, fastparquet
|
||||
@@ -18,7 +18,7 @@ with torch.inference_mode():
|
||||
|
||||
config_full = ExLlamaV2Config()
|
||||
# config_full.model_dir = "/mnt/str/models/llama-7b"
|
||||
config_full.model_dir = "/mnt/str/models/_exl2/llama2-7b"
|
||||
config_full.model_dir = "/mnt/str/models/_exl2/tiefighter-13b/"
|
||||
config_full.prepare()
|
||||
model_full = ExLlamaV2(config_full)
|
||||
model_full.load(lazy = True)
|
||||
@@ -29,7 +29,7 @@ with torch.inference_mode():
|
||||
|
||||
config_quant = ExLlamaV2Config()
|
||||
# config_quant.model_dir = "/mnt/str/models/_exl2/llama-7b-4.0bpw-h6-exl2/"
|
||||
config_quant.model_dir = "/mnt/str/models/_exl2/llama2-7b-5.0bpw-h6-exl2/"
|
||||
config_quant.model_dir = "/mnt/str/models/_exl2/tiefighter-13b-exl3/4.0bpw/"
|
||||
# config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
|
||||
# config_quant.model_dir = "/mnt/str/models/_test_models/TheBloke_WizardLM-30B-Uncensored-GPTQ/"
|
||||
config_quant.prepare()
|
||||
@@ -129,7 +129,7 @@ with torch.inference_mode():
|
||||
|
||||
# Load all matrices in a full layer of the quant model
|
||||
|
||||
target_layer = 4
|
||||
target_layer = 3
|
||||
prefix = f"layers.{target_layer}."
|
||||
|
||||
for k in model_quant.modules_dict.keys():
|
||||
@@ -149,7 +149,8 @@ with torch.inference_mode():
|
||||
module_quant.load()
|
||||
if isinstance(module_quant, ExLlamaV2Linear):
|
||||
|
||||
gi = module_quant.dump_group_info()
|
||||
# gi = module_quant.dump_group_info()
|
||||
gi = "-----"
|
||||
|
||||
mat = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
|
||||
test1 = module_quant.forward(mat, force_cuda = True)
|
||||
|
||||
Reference in New Issue
Block a user