Instrumentation etc.

This commit is contained in:
turboderp
2023-12-10 17:36:40 +01:00
parent 3c43bad57f
commit d8b4efa8d4
3 changed files with 24 additions and 5 deletions

View File

@@ -6,6 +6,12 @@ from exllamav2.linear import ExLlamaV2Linear
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2 import ext
# catch_key = None
# def set_catch(key):
# global catch_key
# catch_key = key
class ExLlamaV2MLP(ExLlamaV2Module):
layer_idx: int
@@ -130,6 +136,10 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.down_proj.set_device_idx(idx)
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
# global catch_key
#
# if self.key == catch_key:
# return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
if self.q_handle is None or intermediates:
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)

View File

@@ -20,6 +20,8 @@ import torch.nn.functional as F
from conversion.tokenize import get_tokens
from conversion.quantize import list_live_tensors
# from exllamav2.mlp import set_catch
import sys
import json
@@ -188,6 +190,8 @@ if args.eval_dataset:
global logits, target_ids, log_probs, token_log_probs
global mean_log_prob, perplexity
# set_catch("model.layers.3")
logprob_sum = 0
logprob_count = 0
@@ -207,6 +211,10 @@ if args.eval_dataset:
logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
logprob_count += 1
# mean_log_prob = logprob_sum / logprob_count
# perplexity = math.exp(-mean_log_prob)
# print(f" -- Token {j}: {perplexity:.4f}")
print()
mean_log_prob = logprob_sum / logprob_count

View File

@@ -2,7 +2,7 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Linear
from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Linear
from exllamav2.tokenizer import ExLlamaV2Tokenizer
import argparse, os, math, time
import pandas, fastparquet
@@ -18,7 +18,7 @@ with torch.inference_mode():
config_full = ExLlamaV2Config()
# config_full.model_dir = "/mnt/str/models/llama-7b"
config_full.model_dir = "/mnt/str/models/_exl2/llama2-7b"
config_full.model_dir = "/mnt/str/models/_exl2/tiefighter-13b/"
config_full.prepare()
model_full = ExLlamaV2(config_full)
model_full.load(lazy = True)
@@ -29,7 +29,7 @@ with torch.inference_mode():
config_quant = ExLlamaV2Config()
# config_quant.model_dir = "/mnt/str/models/_exl2/llama-7b-4.0bpw-h6-exl2/"
config_quant.model_dir = "/mnt/str/models/_exl2/llama2-7b-5.0bpw-h6-exl2/"
config_quant.model_dir = "/mnt/str/models/_exl2/tiefighter-13b-exl3/4.0bpw/"
# config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
# config_quant.model_dir = "/mnt/str/models/_test_models/TheBloke_WizardLM-30B-Uncensored-GPTQ/"
config_quant.prepare()
@@ -129,7 +129,7 @@ with torch.inference_mode():
# Load all matrices in a full layer of the quant model
target_layer = 4
target_layer = 3
prefix = f"layers.{target_layer}."
for k in model_quant.modules_dict.keys():
@@ -149,7 +149,8 @@ with torch.inference_mode():
module_quant.load()
if isinstance(module_quant, ExLlamaV2Linear):
gi = module_quant.dump_group_info()
# gi = module_quant.dump_group_info()
gi = "-----"
mat = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
test1 = module_quant.forward(mat, force_cuda = True)