compare_q.py: Add QTIP wrapper

This commit is contained in:
turboderp
2025-06-08 15:41:30 +02:00
parent f02c9afd6a
commit 32d98c24c1
4 changed files with 104 additions and 0 deletions

View File

@@ -42,6 +42,10 @@ from compare_q_anyprecision import (
load_anyprecision,
fwd_anyprecision,
)
from compare_q_qtip import (
load_qtip,
fwd_qtip,
)
load_fns = {
"transformers_auto_bf16": load_transformers_auto_bf16,
@@ -51,6 +55,7 @@ load_fns = {
"exllamav3": load_exllamav3,
"llamacpp": load_llamacpp,
"anyprecision": load_anyprecision,
"qtip": load_qtip,
}
fwd_fns = {
@@ -59,6 +64,7 @@ fwd_fns = {
"exllamav3": fwd_exllamav3,
"llamacpp": fwd_llamacpp,
"anyprecision": fwd_anyprecision,
"qtip": fwd_qtip,
}
tokenize_fns = {
@@ -230,6 +236,7 @@ def plot(results, args):
"imat": "brown",
"GGUF": "red",
"VPTQ": "blue",
"QTIP": "teal",
"****": "black",
}
for k, v in d.items():

57
eval/compare_q_qtip.py Normal file
View File

@@ -0,0 +1,57 @@
import torch
"""
Very kludgy wrapper for a custom, hacky install of the QTIP repo as a package. Don't expect this to work generally.
Uses the QTIP package for inference so only supports Llama.
"""
try:
import qtip
from qtip.lib.linear.quantized_linear import QuantizedLinear
from qtip.lib.utils.unsafe_import import model_from_hf_path
except ModuleNotFoundError:
pass
def get_tensors_size(tensors):
return 8 * sum(t.element_size() * t.numel() for t in tensors.values() if t is not None)
def get_tensor_size(tensor):
return 8 * tensor.element_size() * tensor.numel()
def get_storage_info(model):
sum_bits = 0
sum_numel = 0
head_bpw = 0
head_numel = 0
for name, module in model.named_modules():
if any(isinstance(module, x) for x in [torch.nn.Linear]):
if module.out_features >= model.vocab_size * 0.9:
head_bpw = module.weight.element_size() * 8
head_numel = module.weight.numel()
else:
sum_bits += get_tensor_size(module.weight)
sum_numel += module.weight.numel()
elif any(isinstance(module, x) for x in [QuantizedLinear]):
sum_bits += get_tensors_size({
"SU": module.SU,
"SV": module.SV,
"tlut": module.tlut,
"trellis": module.trellis,
})
sum_numel += module.in_features * module.out_features
vram_bits = head_numel * head_bpw + sum_bits
return sum_bits / sum_numel, head_bpw, vram_bits
@torch.inference_mode
@torch.compiler.disable
def load_qtip(model_dir: str, auto = False, bf16 = False):
model, model_str = model_from_hf_path(model_dir, max_mem_ratio = 0.7)
bpw_layer, bpw_head, vram_bits = get_storage_info(model)
return model, bpw_layer, bpw_head, vram_bits
@torch.inference_mode
@torch.compiler.disable
def fwd_qtip(model_instance, input_ids: torch.Tensor):
input_ids = input_ids.to("cuda:0")
output = model_instance(input_ids)
return output.logits

View File

@@ -0,0 +1,20 @@
[
{
"model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/2bit-hyb-yaqa",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 2bit-HYB-YAQA"
},
{
"model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/3bit-hyb-yaqa",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 3bit-HYB-YAQA"
},
{
"model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/4bit-hyb-yaqa",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 4bit-HYB-YAQA"
},
]

View File

@@ -0,0 +1,20 @@
[
{
"model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/2bit-hyb",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 2bit-HYB"
},
{
"model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/3bit-hyb",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 3bit-HYB"
},
{
"model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/4bit-hyb",
"load_fn": "qtip",
"fwd_fn": "qtip",
"label": "QTIP 4bit-HYB"
}
]