diff --git a/eval/compare_q.py b/eval/compare_q.py index d243eea..1896d30 100644 --- a/eval/compare_q.py +++ b/eval/compare_q.py @@ -42,6 +42,10 @@ from compare_q_anyprecision import ( load_anyprecision, fwd_anyprecision, ) +from compare_q_qtip import ( + load_qtip, + fwd_qtip, +) load_fns = { "transformers_auto_bf16": load_transformers_auto_bf16, @@ -51,6 +55,7 @@ load_fns = { "exllamav3": load_exllamav3, "llamacpp": load_llamacpp, "anyprecision": load_anyprecision, + "qtip": load_qtip, } fwd_fns = { @@ -59,6 +64,7 @@ fwd_fns = { "exllamav3": fwd_exllamav3, "llamacpp": fwd_llamacpp, "anyprecision": fwd_anyprecision, + "qtip": fwd_qtip, } tokenize_fns = { @@ -230,6 +236,7 @@ def plot(results, args): "imat": "brown", "GGUF": "red", "VPTQ": "blue", + "QTIP": "teal", "****": "black", } for k, v in d.items(): diff --git a/eval/compare_q_qtip.py b/eval/compare_q_qtip.py new file mode 100644 index 0000000..0586261 --- /dev/null +++ b/eval/compare_q_qtip.py @@ -0,0 +1,57 @@ +import torch + +""" +Very kludgy wrapper for a custom, hacky install of the QTIP repo as a package. Don't expect this to work generally. +Uses the QTIP package for inference so only supports Llama. +""" + +try: + import qtip + from qtip.lib.linear.quantized_linear import QuantizedLinear + from qtip.lib.utils.unsafe_import import model_from_hf_path +except ModuleNotFoundError: + pass + +def get_tensors_size(tensors): + return 8 * sum(t.element_size() * t.numel() for t in tensors.values() if t is not None) + +def get_tensor_size(tensor): + return 8 * tensor.element_size() * tensor.numel() + +def get_storage_info(model): + sum_bits = 0 + sum_numel = 0 + head_bpw = 0 + head_numel = 0 + for name, module in model.named_modules(): + if any(isinstance(module, x) for x in [torch.nn.Linear]): + if module.out_features >= model.vocab_size * 0.9: + head_bpw = module.weight.element_size() * 8 + head_numel = module.weight.numel() + else: + sum_bits += get_tensor_size(module.weight) + sum_numel += module.weight.numel() + elif any(isinstance(module, x) for x in [QuantizedLinear]): + sum_bits += get_tensors_size({ + "SU": module.SU, + "SV": module.SV, + "tlut": module.tlut, + "trellis": module.trellis, + }) + sum_numel += module.in_features * module.out_features + vram_bits = head_numel * head_bpw + sum_bits + return sum_bits / sum_numel, head_bpw, vram_bits + +@torch.inference_mode +@torch.compiler.disable +def load_qtip(model_dir: str, auto = False, bf16 = False): + model, model_str = model_from_hf_path(model_dir, max_mem_ratio = 0.7) + bpw_layer, bpw_head, vram_bits = get_storage_info(model) + return model, bpw_layer, bpw_head, vram_bits + +@torch.inference_mode +@torch.compiler.disable +def fwd_qtip(model_instance, input_ids: torch.Tensor): + input_ids = input_ids.to("cuda:0") + output = model_instance(input_ids) + return output.logits diff --git a/eval/spec/llama3.1-70b-instruct_qtip.json b/eval/spec/llama3.1-70b-instruct_qtip.json new file mode 100644 index 0000000..a60ddd0 --- /dev/null +++ b/eval/spec/llama3.1-70b-instruct_qtip.json @@ -0,0 +1,20 @@ +[ + { + "model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/2bit-hyb-yaqa", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 2bit-HYB-YAQA" + }, + { + "model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/3bit-hyb-yaqa", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 3bit-HYB-YAQA" + }, + { + "model_dir": "/mnt/str/models/llama3.1-70b-instruct/qtip/4bit-hyb-yaqa", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 4bit-HYB-YAQA" + }, +] \ No newline at end of file diff --git a/eval/spec/llama3.1-8b-instruct_qtip.json b/eval/spec/llama3.1-8b-instruct_qtip.json new file mode 100644 index 0000000..7ad76a2 --- /dev/null +++ b/eval/spec/llama3.1-8b-instruct_qtip.json @@ -0,0 +1,20 @@ +[ + { + "model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/2bit-hyb", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 2bit-HYB" + }, + { + "model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/3bit-hyb", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 3bit-HYB" + }, + { + "model_dir": "/mnt/str/models/llama3.1-8b-instruct/qtip/4bit-hyb", + "load_fn": "qtip", + "fwd_fn": "qtip", + "label": "QTIP 4bit-HYB" + } +] \ No newline at end of file