From d8b4efa8d444888f4b477fc9da2f33246b92e723 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 10 Dec 2023 17:36:40 +0100 Subject: [PATCH] Instrumentation etc. --- exllamav2/mlp.py | 10 ++++++++++ test_inference.py | 8 ++++++++ tests/test_gemv.py | 11 ++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index d3b8b19..a913cfe 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -6,6 +6,12 @@ from exllamav2.linear import ExLlamaV2Linear from exllamav2.ext import exllamav2_ext as ext_c, none_tensor from exllamav2 import ext +# catch_key = None +# def set_catch(key): +# global catch_key +# catch_key = key + + class ExLlamaV2MLP(ExLlamaV2Module): layer_idx: int @@ -130,6 +136,10 @@ class ExLlamaV2MLP(ExLlamaV2Module): self.down_proj.set_device_idx(idx) def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None): + # global catch_key + # + # if self.key == catch_key: + # return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras) if self.q_handle is None or intermediates: return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras) diff --git a/test_inference.py b/test_inference.py index 5afd004..fe7e5ec 100644 --- a/test_inference.py +++ b/test_inference.py @@ -20,6 +20,8 @@ import torch.nn.functional as F from conversion.tokenize import get_tokens from conversion.quantize import list_live_tensors +# from exllamav2.mlp import set_catch + import sys import json @@ -188,6 +190,8 @@ if args.eval_dataset: global logits, target_ids, log_probs, token_log_probs global mean_log_prob, perplexity + # set_catch("model.layers.3") + logprob_sum = 0 logprob_count = 0 @@ -207,6 +211,10 @@ if args.eval_dataset: logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]] logprob_count += 1 + # mean_log_prob = logprob_sum / logprob_count + # perplexity = math.exp(-mean_log_prob) + # print(f" -- Token {j}: {perplexity:.4f}") + print() mean_log_prob = logprob_sum / logprob_count diff --git a/tests/test_gemv.py b/tests/test_gemv.py index 8689f3c..d87786b 100644 --- a/tests/test_gemv.py +++ b/tests/test_gemv.py @@ -2,7 +2,7 @@ import sys, os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Linear +from exllamav2.model import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Linear from exllamav2.tokenizer import ExLlamaV2Tokenizer import argparse, os, math, time import pandas, fastparquet @@ -18,7 +18,7 @@ with torch.inference_mode(): config_full = ExLlamaV2Config() # config_full.model_dir = "/mnt/str/models/llama-7b" - config_full.model_dir = "/mnt/str/models/_exl2/llama2-7b" + config_full.model_dir = "/mnt/str/models/_exl2/tiefighter-13b/" config_full.prepare() model_full = ExLlamaV2(config_full) model_full.load(lazy = True) @@ -29,7 +29,7 @@ with torch.inference_mode(): config_quant = ExLlamaV2Config() # config_quant.model_dir = "/mnt/str/models/_exl2/llama-7b-4.0bpw-h6-exl2/" - config_quant.model_dir = "/mnt/str/models/_exl2/llama2-7b-5.0bpw-h6-exl2/" + config_quant.model_dir = "/mnt/str/models/_exl2/tiefighter-13b-exl3/4.0bpw/" # config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/" # config_quant.model_dir = "/mnt/str/models/_test_models/TheBloke_WizardLM-30B-Uncensored-GPTQ/" config_quant.prepare() @@ -129,7 +129,7 @@ with torch.inference_mode(): # Load all matrices in a full layer of the quant model - target_layer = 4 + target_layer = 3 prefix = f"layers.{target_layer}." for k in model_quant.modules_dict.keys(): @@ -149,7 +149,8 @@ with torch.inference_mode(): module_quant.load() if isinstance(module_quant, ExLlamaV2Linear): - gi = module_quant.dump_group_info() + # gi = module_quant.dump_group_info() + gi = "-----" mat = torch.eye(module_quant.in_features, dtype = torch.half).cuda() test1 = module_quant.forward(mat, force_cuda = True)