Mirror of https://github.com/turboderp-org/exllamav2.git, synced 2026-03-15 00:07:26 +00:00.
Support Qwen2
This commit is contained in:
@@ -465,6 +465,8 @@ class AdaptiveGPTQ:
|
||||
output[key + ".q_invperm"] = self.invperm.to(torch.int)
|
||||
output[key + ".q_scale_max"] = self.qscale_max
|
||||
output[key + ".q_groups"] = self.qgroups
|
||||
if self.layer.bias is not None:
|
||||
output[key + ".bias"] = self.layer.bias.data
|
||||
|
||||
columns = self.columns
|
||||
rem_rows = self.rows
|
||||
|
||||
@@ -91,8 +91,9 @@ def test_quant(source: ExLlamaV2Linear,
|
||||
variants = []
|
||||
variants_bits = []
|
||||
|
||||
original = nn.Linear(source.in_features, source.out_features, False, device = "meta", dtype = torch.float16)
|
||||
original = nn.Linear(source.in_features, source.out_features, source.has_bias, device = "meta", dtype = torch.float16)
|
||||
original.weight = nn.Parameter(source.linear.weight.clone())
|
||||
if source.has_bias: original.bias.weight = nn.Parameter(source.linear.bias.clone())
|
||||
|
||||
for qp in qparams:
|
||||
|
||||
@@ -102,10 +103,12 @@ def test_quant(source: ExLlamaV2Linear,
|
||||
quantized.to("cpu")
|
||||
|
||||
variants.append(quantized)
|
||||
total_bits = qp.total_bits(quantized.weight.T.shape)
|
||||
total_bits = qp.total_bits(quantized.weight.T.shape, original.bias.weight.shape if source.has_bias else None)
|
||||
variants_bits.append(total_bits)
|
||||
|
||||
bpw = total_bits / quantized.weight.numel()
|
||||
numel = quantized.weight.numel()
|
||||
if source.has_bias: numel += original.bias.numel()
|
||||
bpw = total_bits / numel
|
||||
desc = qp.desc
|
||||
|
||||
print(f" -- {source.key:50} {desc:50} {bpw:2.2f} bpw")
|
||||
|
||||
@@ -57,7 +57,7 @@ class QParams:
|
||||
qp_dict["scale_bits"])
|
||||
|
||||
|
||||
def total_bits(self, shape):
|
||||
def total_bits(self, shape, bias_shape = None):
|
||||
|
||||
rows = shape[0]
|
||||
columns = shape[1]
|
||||
@@ -91,16 +91,26 @@ class QParams:
|
||||
total_bits += groups * columns * self.scale_bits # q_scale
|
||||
total_bits += rows * 32 # q_invperm
|
||||
|
||||
if bias_shape is not None:
|
||||
bias_numel = 1
|
||||
for d in bias_shape: bias_numel *= d
|
||||
total_bits += 16 * d
|
||||
|
||||
return total_bits
|
||||
|
||||
|
||||
def bpw(self, shape, bias_shape = None):
    """
    Effective bits per weight for a tensor of the given shape under these
    quantization params.

    :param shape: (rows, columns) of the weight matrix
    :param bias_shape: optional shape of an fp16 bias tensor; when given,
        the bias elements are included in both the bit total and the element
        count so bpw reflects the full storage cost of the layer
    :return: total stored bits divided by total number of elements
    """

    rows = shape[0]
    columns = shape[1]
    numel = rows * columns

    if bias_shape is not None:
        # Count every bias element. Bug fix: previously added `d` (the size
        # of the last dimension only) instead of the product of all dims.
        bias_numel = 1
        for d in bias_shape: bias_numel *= d
        numel += bias_numel

    return self.total_bits(shape, bias_shape) / numel
|
||||
|
||||
|
||||
def get_desc(self, filename = False):
|
||||
|
||||
@@ -55,7 +55,7 @@ def quant_linear(job: dict,
|
||||
# Quantize
|
||||
|
||||
lq.configure(qp.group_size, qp.bits, qp.bits_prop, qp.scale_bits)
|
||||
lq.quantize(keep_qweight = True, apply = True, drop = drop)
|
||||
lq.quantize(keep_qweight = True, apply = True)
|
||||
|
||||
# Pack and save quantized layer
|
||||
|
||||
@@ -69,10 +69,12 @@ def quant_linear(job: dict,
|
||||
|
||||
# Reconstruct from packed layer
|
||||
|
||||
recons_linear = ExLlamaV2Linear(source.model, source.key, source.in_features, source.out_features, False)
|
||||
recons_linear = ExLlamaV2Linear(source.model, source.key, source.in_features, source.out_features, source.has_bias)
|
||||
recons_linear.device_idx = source.device_idx
|
||||
recons_dict = {}
|
||||
for k in ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]:
|
||||
recons_keys = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]
|
||||
if source.has_bias: recons_keys += ["bias"]
|
||||
for k in recons_keys:
|
||||
recons_dict[k] = packed_dict[source.key + "." + k]
|
||||
recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
|
||||
recons_linear.load(recons_dict)
|
||||
@@ -86,6 +88,7 @@ def quant_linear(job: dict,
|
||||
recons_w2 = recons_linear.forward(ident, force_cuda = True)
|
||||
|
||||
recons_w2.sub_(quant_w)
|
||||
if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
|
||||
recons_w2.abs_()
|
||||
diff2 = torch.max(recons_w2)
|
||||
|
||||
|
||||
@@ -163,10 +163,10 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
else:
|
||||
self.input_layernorm = ExLlamaV2RMSNorm(model, key + ".input_layernorm")
|
||||
|
||||
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, self.model.config.num_attention_heads * self.model.config.head_dim, False)
|
||||
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
|
||||
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, False)
|
||||
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", self.model.config.num_attention_heads * self.model.config.head_dim, hidden_size, False)
|
||||
self.q_proj = ExLlamaV2Linear(model, key + ".self_attn.q_proj", hidden_size, self.model.config.num_attention_heads * self.model.config.head_dim, self.model.config.attention_bias_qkv)
|
||||
self.k_proj = ExLlamaV2Linear(model, key + ".self_attn.k_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, self.model.config.attention_bias_qkv)
|
||||
self.v_proj = ExLlamaV2Linear(model, key + ".self_attn.v_proj", hidden_size, self.model.config.num_key_value_heads * self.model.config.head_dim, self.model.config.attention_bias_qkv)
|
||||
self.o_proj = ExLlamaV2Linear(model, key + ".self_attn.o_proj", self.model.config.num_attention_heads * self.model.config.head_dim, hidden_size, self.model.config.attention_bias_o)
|
||||
|
||||
self.submodules = [self.input_layernorm,
|
||||
self.q_proj,
|
||||
|
||||
@@ -44,6 +44,8 @@ class ExLlamaV2Config:
|
||||
head_dim: int = 128 # Constant for all Llama models, except 3b
|
||||
num_experts: int = None
|
||||
num_experts_per_token: int = None
|
||||
attention_bias_qkv: bool = False
|
||||
attention_bias_o: bool = False
|
||||
|
||||
checkpoint_fused_mlp: bool = False
|
||||
|
||||
@@ -144,6 +146,16 @@ class ExLlamaV2Config:
|
||||
expect_keys += \
|
||||
expect_keys_llama
|
||||
|
||||
if "Qwen2ForCausalLM" in read_config["architectures"]:
|
||||
self.architecture = "Qwen2"
|
||||
layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
expect_keys += \
|
||||
expect_keys_llama
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = False
|
||||
|
||||
else:
|
||||
print(f" !! Warning, unknown architecture: {repr(read_config['architectures'])}")
|
||||
@@ -167,6 +179,9 @@ class ExLlamaV2Config:
|
||||
self.num_hidden_layers = read_config["num_hidden_layers"]
|
||||
self.rms_norm_eps = read_config["rms_norm_eps"]
|
||||
self.vocab_size = read_config["vocab_size"]
|
||||
if read_config.get("attention_bias", False):
|
||||
self.attention_bias_qkv = True
|
||||
self.attention_bias_o = True
|
||||
|
||||
self.rotary_embedding_base = read_config["rope_theta"] if "rope_theta" in read_config else 10000.0
|
||||
|
||||
|
||||
51
exllamav2/exllamav2_ext/cuda/h_add.cu
Normal file
51
exllamav2/exllamav2_ext/cuda/h_add.cu
Normal file
@@ -0,0 +1,51 @@
|
||||
#include "h_add.cuh"
|
||||
#include "util.cuh"
|
||||
#include "../config.h"
|
||||
#include "matrix_view.cuh"
|
||||
|
||||
#define NUM_THREADS_X 32
|
||||
#define NUM_THREADS_Y 16
|
||||
|
||||
// In-place broadcast add: dest[y][x] += source[0][x] for a [height, width]
// fp16 matrix `dest` and a [1, width] fp16 row vector `source`.
// Each thread handles two adjacent columns as one half2; each block in the
// y-grid covers NUM_THREADS_Y consecutive rows, looping over them so the
// vector element is loaded from global memory only once per thread.
__global__ void cuda_vector_add_kernel
(
    half* __restrict__ dest,
    const half* __restrict__ source,
    const int height,
    const int width
)
{
    MatrixView_half_rw dest_(dest, height, width);
    MatrixView_half source_(source, 1, width);

    // Two halves per thread along x; offset_x is always even.
    int offset_x = blockIdx.x * NUM_THREADS_X * 2 + threadIdx.x * 2;
    if (offset_x >= width) return;

    // Row span handled by this block; clamp the end at the matrix height.
    int offset_y = blockIdx.y * NUM_THREADS_Y;
    int end_y = min(offset_y + NUM_THREADS_Y, height);

    // NOTE(review): the half2 load/store assumes `width` is even — for an odd
    // width the last pair would touch one element past the row. Model hidden
    // sizes are even in practice, but confirm at the call sites.
    half2 v = source_.item_half2(0, offset_x);
    for (int y = offset_y; y < end_y; ++y)
    {
        half2* ptr = (half2*) dest_.item_ptr(y, offset_x);
        *ptr = __hadd2(v, *ptr);
    }
}
|
||||
|
||||
// Host-side launcher for cuda_vector_add_kernel: adds the [1, width] fp16 row
// vector `source` to every row of the [height, width] fp16 matrix `dest`,
// in place, on the current CUDA stream (default launch).
void cuda_vector_add_
(
    half* dest,
    const half* source,
    int height,
    int width
)
{
    dim3 blockDim, gridDim;
    blockDim.x = NUM_THREADS_X;
    blockDim.y = 1;               // rows are looped inside the kernel, not threaded
    blockDim.z = 1;
    gridDim.x = DIVIDE(width, NUM_THREADS_X * 2);   // 2 columns per thread (half2)
    gridDim.y = DIVIDE(height, NUM_THREADS_Y);      // NUM_THREADS_Y rows per block
    gridDim.z = 1;

    cuda_vector_add_kernel<<<gridDim, blockDim>>>(dest, source, height, width);
}
|
||||
18
exllamav2/exllamav2_ext/cuda/h_add.cuh
Normal file
18
exllamav2/exllamav2_ext/cuda/h_add.cuh
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef _h_add_cuh
|
||||
#define _h_add_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// Adds a [1, width] fp16 row vector `source` to every row of the
// [height, width] fp16 matrix `dest`, in place.
// Fix: parameter names were declared in (width, height) order while the
// definition in h_add.cu and the call site pass (height, width) — the names
// are corrected here to match and avoid misleading callers. (C++ declaration
// parameter names have no ABI effect, so this is purely a clarity fix.)
void cuda_vector_add_
(
    half* dest,
    const half* source,
    int height,
    int width
);
|
||||
|
||||
#endif
|
||||
@@ -5,6 +5,7 @@
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "comp_units/kernel_select.cuh"
|
||||
#include "h_add.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
@@ -165,6 +166,8 @@ void gemm_half_q_half_cuda
|
||||
int block_m = min(size_m, block_m_size_max);
|
||||
gemm_half_q_half_cuda_part(a, b, c, size_m, size_n, size_k, block_m, clear, r_weights, r_weights_stride, mul_r_weights);
|
||||
}
|
||||
|
||||
if (b->cuda_bias) cuda_vector_add_(c, b->cuda_bias, size_m, size_n);
|
||||
}
|
||||
|
||||
__global__ void clear_kernel
|
||||
|
||||
@@ -63,6 +63,8 @@ QMatrix::QMatrix
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _bias,
|
||||
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
@@ -84,6 +86,7 @@ QMatrix::QMatrix
|
||||
cuda_q_group_map = _q_group_map;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
cuda_bias = _bias;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ public:
|
||||
uint16_t* cuda_q_group_map = NULL;
|
||||
uint32_t* cuda_gptq_qzeros = NULL;
|
||||
half* cuda_gptq_scales = NULL;
|
||||
half* cuda_bias = NULL;
|
||||
|
||||
half* temp_dq;
|
||||
|
||||
@@ -61,6 +62,8 @@ public:
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* bias,
|
||||
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ uintptr_t make_q_matrix
|
||||
torch::Tensor gptq_qzeros,
|
||||
torch::Tensor gptq_scales,
|
||||
torch::Tensor gptq_g_idx,
|
||||
torch::Tensor bias,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
@@ -42,6 +43,7 @@ uintptr_t make_q_matrix
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(bias, kHalf);
|
||||
|
||||
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||
|
||||
@@ -65,6 +67,11 @@ uintptr_t make_q_matrix
|
||||
height = q_weight.size(0) * 8;
|
||||
}
|
||||
|
||||
if (!bias.device().is_meta())
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, bias, 0, 1);
|
||||
}
|
||||
|
||||
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||
|
||||
QMatrix* m = new QMatrix
|
||||
@@ -83,6 +90,7 @@ uintptr_t make_q_matrix
|
||||
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||
bias.device().is_meta() ? NULL : (half*) bias.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ uintptr_t make_q_matrix
|
||||
torch::Tensor gptq_qzeros,
|
||||
torch::Tensor gptq_scales,
|
||||
torch::Tensor gptq_g_idx,
|
||||
torch::Tensor bias,
|
||||
torch::Tensor temp_dq
|
||||
);
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ if build_jit:
|
||||
"ext_rope.cpp",
|
||||
"ext_safetensors.cpp",
|
||||
"ext_sampling.cpp",
|
||||
"cuda/h_add.cu",
|
||||
"cuda/h_gemm.cu",
|
||||
"cuda/lora.cu",
|
||||
"cuda/pack_tensor.cu",
|
||||
@@ -216,6 +217,7 @@ def make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w.get("bias", none_tensor),
|
||||
temp_dq)
|
||||
|
||||
# GPTQ
|
||||
@@ -241,6 +243,7 @@ def make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w["g_idx"].cpu(),
|
||||
w.get("bias", none_tensor),
|
||||
temp_dq)
|
||||
|
||||
# GPTQ without g_idx
|
||||
@@ -257,6 +260,7 @@ def make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
none_tensor,
|
||||
w.get("bias", none_tensor),
|
||||
temp_dq)
|
||||
|
||||
|
||||
|
||||
@@ -199,6 +199,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
return tuple(ret)
|
||||
|
||||
|
||||
# @profile
|
||||
def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
|
||||
|
||||
# Token healing
|
||||
@@ -240,9 +241,9 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.settings.begin_filters()
|
||||
self.first_token = False
|
||||
|
||||
# Decode the current tail end of the sequence
|
||||
|
||||
old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
|
||||
# # Decode the current tail end of the sequence
|
||||
#
|
||||
# old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
|
||||
|
||||
# Generate a single token and append to the sequence
|
||||
|
||||
@@ -253,10 +254,13 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
if next_token.item() in self.stop_tokens:
|
||||
return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
|
||||
|
||||
# Decode the tail end of the sequence with the added token to get (actual) characters added
|
||||
# # Decode the tail end of the sequence with the added token to get (actual) characters added
|
||||
#
|
||||
# new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
|
||||
# new_text = new_tail[len(old_tail):]
|
||||
|
||||
new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
|
||||
new_text = new_tail[len(old_tail):]
|
||||
piece_to_id = self.tokenizer.get_id_to_piece_list()
|
||||
new_text = piece_to_id[self.sequence_ids[0, -1].item()]
|
||||
|
||||
next_token, new_text = self._catch_utf8(next_token, new_text)
|
||||
|
||||
|
||||
@@ -44,6 +44,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
|
||||
if w is None: w = self.load_weight()
|
||||
if isinstance(w, dict):
|
||||
if self.has_bias:
|
||||
assert "bias" in w, self.key + " has no bias but bias expected"
|
||||
else:
|
||||
assert "bias" not in w, self.key + " has bias but bias is not expected"
|
||||
device_tensors = self.model.get_device_tensors(self.device_idx)
|
||||
device_tensors.begin_scratch_alloc()
|
||||
self.temp_dq = device_tensors.get_scratch_slice(self.temp_dq_size())
|
||||
@@ -51,10 +55,22 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
self.q_handle = ext.make_q_matrix(w, self.temp_dq)
|
||||
|
||||
elif isinstance(w, nn.Parameter):
|
||||
assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
|
||||
if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
|
||||
self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
|
||||
self.linear.weight = w
|
||||
|
||||
elif isinstance(w, tuple):
|
||||
assert self.has_bias, self.key + " has bias tensor but bias is not expected"
|
||||
ww = w[0]
|
||||
wb = w[1]
|
||||
if self.padding > 0:
|
||||
ww = nn.Parameter(F.pad(ww.data, (0, 0, 0, self.padding)).contiguous())
|
||||
wb = nn.Parameter(F.pad(wb.data, (0, 0, 0, self.padding)).contiguous())
|
||||
self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
|
||||
self.linear.weight = ww
|
||||
self.linear.bias = wb
|
||||
|
||||
|
||||
def matrix_shape(self):
|
||||
|
||||
@@ -129,6 +145,10 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
matrix = self.get_weight_tensor_dq()
|
||||
hidden_states_out = torch.matmul(hidden_states, matrix)
|
||||
|
||||
if self.has_bias:
|
||||
bias = self.get_bias_tensor()
|
||||
hidden_states_out += bias
|
||||
|
||||
# Evaluate LoRAs
|
||||
|
||||
if loras is not None:
|
||||
@@ -198,6 +218,17 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
raise ValueError(f"Layer {self.key} has no data")
|
||||
|
||||
|
||||
def get_bias_tensor(self):
    """
    Return this layer's bias tensor, whichever representation is loaded:
    the fp16 ``nn.Linear`` module's bias, or the ``"bias"`` entry of the
    packed quantized tensor dict. Raises ``ValueError`` if neither is loaded.
    """

    # Unquantized path: bias lives on the nn.Linear module.
    linear = self.linear
    if linear is not None:
        return linear.bias.data

    # Quantized path: bias was packed alongside the q_* tensors.
    if self.q_handle is not None:
        return self.q_tensors["bias"]

    raise ValueError(f"Layer {self.key} has no data")
|
||||
|
||||
def is_quant(self):
|
||||
|
||||
return self.q_handle is not None
|
||||
|
||||
@@ -9,8 +9,6 @@ if sys.version_info < min_version:
|
||||
# Set CUDA context to lazy loading since we won't need 95% of the modules in Torch
|
||||
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
||||
|
||||
# Disabled for 0.0.13.post2
|
||||
#
|
||||
# # Set cudaMallocAsync allocator by default as it appears slightly more memory efficient, unless Torch is already
|
||||
# # imported in which case changing the allocator would cause it to crash
|
||||
# if not "PYTORCH_CUDA_ALLOC_CONF" in os.environ:
|
||||
|
||||
@@ -93,14 +93,14 @@ class ExLlamaV2Module:
|
||||
# EXL2
|
||||
|
||||
if key + ".q_weight" in self.model.config.tensor_file_map:
|
||||
qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"], override_key = override_key)
|
||||
qtensors = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], override_key = override_key)
|
||||
qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
|
||||
return qtensors
|
||||
|
||||
# GPTQ
|
||||
|
||||
if key + ".qweight" in self.model.config.tensor_file_map:
|
||||
qtensors = self.load_multi(["qweight", "qzeros", "scales", "g_idx"], override_key = override_key)
|
||||
qtensors = self.load_multi(["qweight", "qzeros", "scales", "g_idx", "bias"], override_key = override_key)
|
||||
qtensors["scales"] = qtensors["scales"].half()
|
||||
return qtensors
|
||||
|
||||
|
||||
1
setup.py
1
setup.py
@@ -42,6 +42,7 @@ setup_kwargs = {
|
||||
"exllamav2/exllamav2_ext/ext_rope.cpp",
|
||||
"exllamav2/exllamav2_ext/ext_safetensors.cpp",
|
||||
"exllamav2/exllamav2_ext/ext_sampling.cpp",
|
||||
"exllamav2/exllamav2_ext/cuda/h_add.cu",
|
||||
"exllamav2/exllamav2_ext/cuda/h_gemm.cu",
|
||||
"exllamav2/exllamav2_ext/cuda/lora.cu",
|
||||
"exllamav2/exllamav2_ext/cuda/pack_tensor.cu",
|
||||
|
||||
Reference in New Issue
Block a user