mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Refactor to consolidate attn params
This commit is contained in:
@@ -90,14 +90,14 @@ def test_quant(source: ExLlamaV2Linear,
|
||||
return variants, variants_bits
|
||||
|
||||
|
||||
def test_error(module, hidden_states, target_states, cache, attn_mask):
|
||||
def test_error(module, hidden_states, target_states, cache, attn_params):
|
||||
|
||||
rfn_sum = 0
|
||||
rfn_count = 0
|
||||
for x, xref in zip(hidden_states, target_states):
|
||||
x = x.cuda()
|
||||
xref = xref.cuda()
|
||||
xtest = module.forward(x, cache, attn_mask)
|
||||
xtest = module.forward(x, cache, attn_params)
|
||||
xtest = xtest[0].float()
|
||||
xref = xref[0].float()
|
||||
rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
|
||||
@@ -106,7 +106,7 @@ def test_error(module, hidden_states, target_states, cache, attn_mask):
|
||||
return max(1e-6, 1 - (rfn_sum / rfn_count)).item()
|
||||
|
||||
|
||||
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_mask):
|
||||
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params):
|
||||
|
||||
qjobs, qmaps = get_qparams_reduced(qparams_attn)
|
||||
results = []
|
||||
@@ -141,7 +141,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_m
|
||||
total_bits += bits_o[o]
|
||||
total_bpw = total_bits / total_numel
|
||||
|
||||
accuracy = test_error(module, hidden_states, target_states, cache, attn_mask)
|
||||
accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
|
||||
print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
@@ -157,7 +157,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_m
|
||||
return results
|
||||
|
||||
|
||||
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask):
|
||||
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params):
|
||||
|
||||
qjobs, qmaps = get_qparams_reduced(qparams_mlp)
|
||||
results = []
|
||||
@@ -187,7 +187,7 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_ma
|
||||
total_bits += bits_d[d]
|
||||
total_bpw = total_bits / total_numel
|
||||
|
||||
accuracy = test_error(module, hidden_states, target_states, cache, attn_mask)
|
||||
accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
|
||||
print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
@@ -333,7 +333,7 @@ def measure_quant(job, save_fn, model):
|
||||
# Reference forward pass
|
||||
|
||||
cache = None
|
||||
attn_mask = model.build_attn_mask(1, hidden_states[0].shape[1], 0, None, "cuda:0") if mode == "self_attn" else None
|
||||
attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None
|
||||
|
||||
target_states = []
|
||||
if mode == "block_sparse_moe":
|
||||
@@ -342,7 +342,7 @@ def measure_quant(job, save_fn, model):
|
||||
for i in range(len(hidden_states)):
|
||||
|
||||
x = hidden_states[i].to("cuda:0")
|
||||
outputs = module.forward(x, cache, attn_mask, intermediates = True)
|
||||
outputs = module.forward(x, cache, attn_params, intermediates = True)
|
||||
|
||||
# Hessians
|
||||
|
||||
@@ -379,13 +379,13 @@ def measure_quant(job, save_fn, model):
|
||||
m = None
|
||||
|
||||
if mode == "self_attn":
|
||||
m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_mask)
|
||||
m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params)
|
||||
|
||||
if mode == "mlp":
|
||||
m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask)
|
||||
m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)
|
||||
|
||||
if mode == "block_sparse_moe":
|
||||
m = measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask)
|
||||
m = measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)
|
||||
|
||||
measurement[module.key + "." + mode] = m
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ def quant_linear(job: dict,
|
||||
source.linear.weight.data = recons_w.T
|
||||
|
||||
|
||||
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
|
||||
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
|
||||
|
||||
quantizers["q_proj"].prepare()
|
||||
quantizers["k_proj"].reuse_h(quantizers["q_proj"])
|
||||
@@ -119,7 +119,7 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, cache, att
|
||||
quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"])
|
||||
|
||||
|
||||
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
|
||||
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
|
||||
|
||||
quantizers["gate_proj"].prepare()
|
||||
quantizers["up_proj"].reuse_h(quantizers["gate_proj"])
|
||||
@@ -135,7 +135,7 @@ def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn
|
||||
del quantizers[f"down_proj"]
|
||||
|
||||
|
||||
def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
|
||||
def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
|
||||
|
||||
num_experts = module.model.config.num_experts
|
||||
|
||||
@@ -154,7 +154,7 @@ def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache,
|
||||
del quantizers[f"w2.{i}"]
|
||||
|
||||
|
||||
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask):
|
||||
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params):
|
||||
|
||||
quantizers["lm_head"].prepare()
|
||||
|
||||
@@ -269,7 +269,7 @@ def quant(job, save_fn, model):
|
||||
# Reference forward pass
|
||||
|
||||
cache = None
|
||||
attn_mask = model.build_attn_mask(1, hidden_states[0].shape[1], 0, None, "cuda:0") if mode == "self_attn" else None
|
||||
attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None
|
||||
|
||||
target_states = []
|
||||
if mode == "block_sparse_moe":
|
||||
@@ -278,7 +278,7 @@ def quant(job, save_fn, model):
|
||||
for i in range(len(hidden_states)):
|
||||
|
||||
x = hidden_states[i].to("cuda:0")
|
||||
outputs = module.forward(x, cache, attn_mask, intermediates = True)
|
||||
outputs = module.forward(x, cache, attn_params, intermediates = True)
|
||||
|
||||
# Hessians
|
||||
|
||||
@@ -318,18 +318,18 @@ def quant(job, save_fn, model):
|
||||
|
||||
if mode == "self_attn":
|
||||
strat = strategy[module.key + "." + mode]
|
||||
quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
|
||||
quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
|
||||
|
||||
if mode == "mlp":
|
||||
strat = strategy[module.key + "." + mode]
|
||||
quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
|
||||
quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
|
||||
|
||||
if mode == "block_sparse_moe":
|
||||
strat = strategy[module.key + "." + mode]
|
||||
quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
|
||||
quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
|
||||
|
||||
if mode == "linear":
|
||||
quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask)
|
||||
quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params)
|
||||
|
||||
quantizers.clear()
|
||||
gc.collect()
|
||||
@@ -352,7 +352,7 @@ def quant(job, save_fn, model):
|
||||
if mode != "linear":
|
||||
|
||||
x = hidden_states[i].to("cuda:0")
|
||||
output = module.forward(x, cache, attn_mask)
|
||||
output = module.forward(x, cache, attn_params)
|
||||
q_states.append(output.to("cpu"))
|
||||
|
||||
output = output[0].float()
|
||||
@@ -365,7 +365,7 @@ def quant(job, save_fn, model):
|
||||
elif i < job["measurement_rows"]:
|
||||
|
||||
x = hidden_states[i].to("cuda:0")
|
||||
output = module.forward(x, cache, attn_mask)
|
||||
output = module.forward(x, cache, attn_params)
|
||||
if module.padding > 0: output = output[:, :, :-module.padding]
|
||||
|
||||
logits = output[:, :-1, :]
|
||||
|
||||
@@ -18,7 +18,7 @@ import time
|
||||
|
||||
# Initialize model and cache
|
||||
|
||||
model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
|
||||
model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/4.0bpw/"
|
||||
|
||||
config = ExLlamaV2Config()
|
||||
config.model_dir = model_directory
|
||||
|
||||
@@ -11,6 +11,7 @@ from exllamav2 import ext
|
||||
from exllamav2.ext import exllamav2_ext as ext_c
|
||||
# import xformers.ops as xops
|
||||
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
|
||||
from exllamav2.compat import safe_move_tensor
|
||||
|
||||
# Detect flash-attn
|
||||
|
||||
@@ -51,6 +52,104 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
temp_lora_size: int = 0
|
||||
|
||||
|
||||
class Params:
    """Attention parameters for one forward pass, shared by all attention layers.

    Bundles the batch geometry, the optional input (padding) mask and the
    optional position offsets, and lazily builds and caches the additive
    attention mask(s) derived from them. Cached tensors are relocated on
    demand when a different device is requested.

    Two modes are supported, selected by the type of ``past_len``:

    * single cache: ``past_len`` is a scalar, ``multi_cache`` is False and
      one mask of shape (batch_size, 1, seq_len, past_len + seq_len) is used;
    * multiple caches: ``past_len`` is a list with one entry per cache,
      ``multi_cache`` is True and one mask with batch dimension 1 is built
      per entry.
    """

    def __init__(self, batch_size, seq_len, past_len, input_mask, position_offsets):
        """Store the parameters; no tensors are built until they are requested.

        :param batch_size: number of sequences in the batch
        :param seq_len: number of new positions processed in this pass
        :param past_len: scalar past length, or a list of per-cache past lengths
        :param input_mask: optional additive mask, or None for plain causal
            attention. In multi-cache mode it is indexed per cache.
            NOTE(review): each (per-cache) mask is sliced as ``[:, :width]``,
            so it is presumably 2-D (batch, positions) - confirm with callers.
        :param position_offsets: optional tensor of position offsets, or None
        """

        self.batch_size = batch_size
        self.seq_len = seq_len

        if isinstance(past_len, list):
            self.past_len = None
            self.past_lens = past_len
            self.multi_cache = True
        else:
            self.past_len = past_len
            self.past_lens = None
            self.multi_cache = False

        self.input_mask = input_mask

        # Lazily built by get_attn_mask() / get_attn_masks()
        self.attn_mask = None
        self.attn_masks = None

        self.position_offsets = position_offsets

        # Lazily built by get_past_lens()
        self.past_lens_tensor = None

    def is_causal(self):
        """Return True when no input mask was supplied (purely causal attention)."""

        return self.input_mask is None

    def get_position_offsets(self, device):
        """Return the position offsets tensor, moving it to ``device`` if needed.

        The moved tensor replaces the stored one, so repeated calls for the
        same device do not copy again. Asserts that offsets were provided.
        """

        assert self.position_offsets is not None
        if self.position_offsets.device != device:
            self.position_offsets = safe_move_tensor(self.position_offsets, device)
        return self.position_offsets

    def get_past_lens(self, device):
        """Return the per-cache past lengths as an int tensor on ``device``.

        The tensor is created on first use and moved (and re-cached) when a
        different device is requested. Asserts multi-cache past lengths exist.
        """

        assert self.past_lens is not None
        if self.past_lens_tensor is None:
            self.past_lens_tensor = torch.tensor(self.past_lens, dtype = torch.int, device = device)
        elif self.past_lens_tensor.device != device:
            self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, device)
        return self.past_lens_tensor

    def get_attn_mask(self, device):
        """Return the single-cache attention mask on ``device`` (may be None).

        None is returned when build_attn_mask() decides no mask is required;
        in that case nothing is cached and the (cheap) check is repeated on
        the next call.
        """

        if self.attn_mask is None:
            self.attn_mask = self.build_attn_mask(device)
        elif self.attn_mask.device != device:
            self.attn_mask = safe_move_tensor(self.attn_mask, device)
        return self.attn_mask

    def get_attn_masks(self, device):
        """Return the list of per-cache attention masks on ``device``.

        Entries may be None when no mask is required. Masks are relocated
        individually, so an empty list or None entries are handled without
        indexing into the list.
        """

        if self.attn_masks is None:
            self.attn_masks = self.build_attn_masks(device)
        elif any(m is not None and m.device != device for m in self.attn_masks):
            self.attn_masks = [
                (safe_move_tensor(m, device) if m is not None and m.device != device else m)
                for m in self.attn_masks
            ]
        return self.attn_masks

    def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_mask):
        """Build one additive float16 mask of shape (batch_size, 1, seq_len, past_len + seq_len).

        Allowed positions hold 0 and future positions hold -inf. If
        ``input_mask`` is given, it is combined with the causal mask by
        element-wise minimum over the columns both masks cover.
        """

        attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)

        # Strict upper triangle of -inf over the new positions: query row r may
        # not attend to key columns past_len + r + 1 and beyond. The last row
        # sees everything, hence the (seq_len - 1) square.
        attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
        attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu

        if input_mask is not None:
            # Only the overlapping columns can be merged
            min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
            input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
            # Insert singleton dims 1 and 2 so the mask broadcasts against attn_mask
            input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
            attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)

        return attn_mask

    def build_attn_mask(self, device):
        """Build the mask for single-cache mode.

        Returns None when there is no input mask and only one new position,
        since a single query position may attend to every key.
        """

        assert not self.multi_cache, "Building single mask for multiple caches"

        if self.input_mask is None and self.seq_len == 1: return None
        return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask)

    def build_attn_masks(self, device):
        """Build one mask (batch dimension 1) per cache for multi-cache mode.

        Entries are None when there is no input mask and only one new
        position. When no input mask was supplied, purely causal masks are
        built instead of indexing into None.
        """

        assert self.multi_cache, "Building multiple masks for single cache"

        attn_masks = []
        for i, past_len in enumerate(self.past_lens):
            if self.input_mask is None and self.seq_len == 1:
                attn_masks.append(None)
            else:
                # input_mask is optional: only index it when it exists
                input_mask = self.input_mask[i] if self.input_mask is not None else None
                attn_masks.append(self.build_single_attn_mask(1, self.seq_len, past_len, device, input_mask))
        return attn_masks
|
||||
|
||||
|
||||
def __init__(self, model, key, layer_idx):
|
||||
super().__init__(model, key)
|
||||
|
||||
@@ -225,11 +324,11 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
|
||||
global has_flash_attn
|
||||
|
||||
if self.q_handle is None or intermediates:
|
||||
return self.forward_torch(hidden_states, cache, attn_mask, past_len, intermediates, loras = loras, position_offsets = position_offsets)
|
||||
return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
q_len = hidden_states.shape[1]
|
||||
@@ -278,12 +377,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
pass_loras = [id(x) for x in loras]
|
||||
pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
|
||||
|
||||
if isinstance(past_len, tuple):
|
||||
if attn_params.multi_cache:
|
||||
pass_past_len_1 = -1
|
||||
pass_past_len_2 = past_len[0]
|
||||
elif position_offsets is not None:
|
||||
pass_past_len_2 = attn_params.get_past_lens(hidden_states.device)
|
||||
elif attn_params.position_offsets is not None:
|
||||
pass_past_len_1 = past_len
|
||||
pass_past_len_2 = position_offsets
|
||||
pass_past_len_2 = attn_params.get_position_offsets(hidden_states.device)
|
||||
else:
|
||||
pass_past_len_1 = past_len
|
||||
pass_past_len_2 = ext.none_tensor
|
||||
@@ -336,7 +435,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
# Torch matmul attention
|
||||
|
||||
if self.model.config.no_flash_attn or not has_flash_attn:
|
||||
if self.model.config.no_flash_attn or not has_flash_attn or not attn_params.is_causal():
|
||||
|
||||
q_states = q_states.transpose(1, 2)
|
||||
k_states = k_states.transpose(1, 2)
|
||||
@@ -350,6 +449,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
q_states = None
|
||||
|
||||
attn_weights /= math.sqrt(head_dim)
|
||||
attn_mask = attn_params.get_attn_mask(hidden_states.device)
|
||||
if attn_mask is not None: attn_weights = attn_weights + attn_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
|
||||
|
||||
@@ -364,6 +464,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
else:
|
||||
|
||||
# TODO: Enable flash-attn with input mask
|
||||
attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
|
||||
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
|
||||
|
||||
@@ -394,6 +495,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
else:
|
||||
|
||||
assert attn_params.multi_cache
|
||||
attn_masks = attn_params.get_attn_masks(hidden_states.device)
|
||||
|
||||
attn_outputs = []
|
||||
for i in range(len(cache)):
|
||||
|
||||
@@ -401,20 +505,20 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
# Add keys and values to cache
|
||||
|
||||
batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, 1, 0, past_len[1][i].item())
|
||||
new_keys = batch_keys.narrow(1, past_len[1][i], q_len)
|
||||
new_values = batch_values.narrow(1, past_len[1][i], q_len)
|
||||
batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, 1, 0, past_len[i])
|
||||
new_keys = batch_keys.narrow(1, past_len[i], q_len)
|
||||
new_values = batch_values.narrow(1, past_len[i], q_len)
|
||||
new_keys.copy_(k_states.narrow(0, i, 1))
|
||||
new_values.copy_(v_states.narrow(0, i, 1))
|
||||
|
||||
# Store updated cache values
|
||||
|
||||
cache[i].store_kv_state(self.layer_idx, 1, past_len[1][i].item(), q_len)
|
||||
cache[i].store_kv_state(self.layer_idx, 1, past_len[i], q_len)
|
||||
|
||||
# Key/value tensors with past
|
||||
|
||||
k_states_b = batch_keys.narrow(1, 0, past_len[1][i] + q_len)
|
||||
v_states_b = batch_values.narrow(1, 0, past_len[1][i] + q_len)
|
||||
k_states_b = batch_keys.narrow(1, 0, past_len[i] + q_len)
|
||||
v_states_b = batch_values.narrow(1, 0, past_len[i] + q_len)
|
||||
|
||||
# Torch matmul attention
|
||||
|
||||
@@ -432,7 +536,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
k_states_b = None
|
||||
|
||||
attn_weights /= math.sqrt(head_dim)
|
||||
if attn_mask is not None: attn_weights = attn_weights + attn_mask[i]
|
||||
if attn_masks[i] is not None: attn_weights = attn_weights + attn_masks[i]
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
|
||||
|
||||
v_states_b = self.repeat_kv(v_states_b, num_key_value_groups)
|
||||
@@ -465,7 +569,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward_torch(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
|
||||
|
||||
num_attention_heads = self.model.config.num_attention_heads
|
||||
num_key_value_heads = self.model.config.num_key_value_heads
|
||||
@@ -506,9 +610,13 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
constants = self.model.get_device_tensors(self.device_idx, scratch = False)
|
||||
|
||||
offset_tensor = position_offsets if position_offsets is not None else ext.none_tensor
|
||||
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, offset_tensor)
|
||||
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, offset_tensor)
|
||||
if attn_params.position_offsets is not None:
|
||||
position_offsets = attn_params.get_position_offsets(hidden_states.device)
|
||||
else:
|
||||
position_offsets = ext.none_tensor
|
||||
|
||||
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets)
|
||||
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets)
|
||||
|
||||
# Add keys and values to cache
|
||||
|
||||
@@ -527,7 +635,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
# Torch matmul attention
|
||||
|
||||
if self.model.config.no_flash_attn or not has_flash_attn:
|
||||
if self.model.config.no_flash_attn or not has_flash_attn or not attn_params.is_causal():
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
@@ -538,6 +646,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states)
|
||||
attn_weights /= math.sqrt(head_dim)
|
||||
attn_mask = attn_params.get_attn_mask(hidden_states.device)
|
||||
if attn_mask is not None: attn_weights = attn_weights + attn_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
|
||||
return 0
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
|
||||
|
||||
hidden_states = self.embedding.forward(hidden_states)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
return self.out_features * self.model.config.max_input_len * self.model.config.max_batch_size * 4 + 128
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, force_recons = False, force_cuda = False, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, force_recons = False, force_cuda = False):
|
||||
|
||||
# Linear forward
|
||||
|
||||
|
||||
@@ -158,14 +158,14 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.down_proj.set_device_idx(idx)
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
|
||||
# global catch_key
|
||||
#
|
||||
# if self.key == catch_key:
|
||||
# return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
# return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
|
||||
|
||||
if self.q_handle is None or intermediates:
|
||||
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
|
||||
|
||||
if loras is None or self.temp_lora_size == 0:
|
||||
pass_loras = []
|
||||
@@ -182,7 +182,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(self, hidden_states, cache = None, attn_mask = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward_torch(self, hidden_states, cache = None, attn_params = None, intermediates = False, loras = None, position_offsets = None):
|
||||
|
||||
residual = hidden_states
|
||||
post_norm = self.post_attention_layernorm.forward(hidden_states)
|
||||
|
||||
@@ -172,7 +172,7 @@ class ExLlamaV2:
|
||||
# TODO: Option to reserve space for cache while loading model
|
||||
|
||||
state_size = self.config.hidden_size * self.config.max_input_len * self.config.max_batch_size * 2
|
||||
mask_size = self.config.max_input_len ** 2 * 2
|
||||
mask_size = self.config.max_input_len ** 2 * self.config.max_batch_size * 2
|
||||
|
||||
# Bytes remaining per device
|
||||
|
||||
@@ -304,7 +304,7 @@ class ExLlamaV2:
|
||||
hidden_state = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
|
||||
batch_size, seq_len = hidden_state.shape
|
||||
past_len = 0
|
||||
attn_mask = None
|
||||
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, None, None)
|
||||
|
||||
# Size of fixed scratch space
|
||||
|
||||
@@ -331,16 +331,16 @@ class ExLlamaV2:
|
||||
|
||||
while True:
|
||||
|
||||
# If we've reached a new device, allocate fixed tensors and attention mask
|
||||
# If we've reached a new device, allocate fixed tensors
|
||||
|
||||
if current_device > last_touched_device:
|
||||
|
||||
self.device_tensors.append(ExLlamaV2DeviceTensors(self, current_device, scratch_fixed))
|
||||
if attn_mask is not None:
|
||||
reserved_vram_tensors.append(attn_mask)
|
||||
attn_mask = safe_move_tensor(attn_mask, _torch_device(current_device))
|
||||
else:
|
||||
attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, None, _torch_device(current_device))
|
||||
# if attn_mask is not None:
|
||||
# reserved_vram_tensors.append(attn_mask)
|
||||
# attn_mask = safe_move_tensor(attn_mask, _torch_device(current_device))
|
||||
# else:
|
||||
# attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, None, _torch_device(current_device))
|
||||
|
||||
b = reserve_vram[current_device]
|
||||
reserved_vram_tensors.append(torch.empty((b,), dtype = torch.int8, device = _torch_device(current_device)))
|
||||
@@ -366,7 +366,7 @@ class ExLlamaV2:
|
||||
hidden_state = hidden_state.narrow(-2, -1, 1)
|
||||
|
||||
hidden_state = safe_move_tensor(hidden_state, _torch_device(current_device))
|
||||
hidden_state = module.forward(hidden_state, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras)
|
||||
hidden_state = module.forward(hidden_state, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
|
||||
fail = False
|
||||
|
||||
except Exception as e:
|
||||
@@ -403,7 +403,7 @@ class ExLlamaV2:
|
||||
if callback_gen is not None: yield from callback_gen(len(self.modules), len(self.modules))
|
||||
|
||||
hidden_state = None
|
||||
attn_mask = None
|
||||
attn_params = None
|
||||
reserved_vram_tensors = None
|
||||
|
||||
gc.collect()
|
||||
@@ -472,44 +472,6 @@ class ExLlamaV2:
|
||||
return False
|
||||
|
||||
|
||||
def build_attn_mask(self, batch_size, seq_len, past_len, input_mask, device):
|
||||
|
||||
if input_mask is None and seq_len == 1: return None
|
||||
|
||||
if isinstance(past_len, tuple):
|
||||
|
||||
attn_masks = []
|
||||
|
||||
for i in range(len(past_len[1])):
|
||||
|
||||
attn_mask = torch.zeros((1, 1, seq_len, past_len[1][i] + seq_len), dtype = torch.float16, device = device)
|
||||
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
|
||||
attn_mask[:, :, : seq_len - 1, past_len[1][i] + 1: past_len[1][i] + seq_len] = attn_mask_triu
|
||||
|
||||
if input_mask is not None:
|
||||
min_mask_width = min(input_mask[i].shape[-1], seq_len + past_len[1][i])
|
||||
input_mask_part = safe_move_tensor(input_mask[i][:, :min_mask_width], attn_mask.device)
|
||||
input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
|
||||
attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)
|
||||
|
||||
attn_masks.append(attn_mask)
|
||||
|
||||
return attn_masks
|
||||
|
||||
else:
|
||||
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)
|
||||
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
|
||||
attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu
|
||||
|
||||
if input_mask is not None:
|
||||
min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
|
||||
input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
|
||||
input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
|
||||
attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)
|
||||
|
||||
return attn_mask
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self,
|
||||
input_ids,
|
||||
@@ -626,30 +588,18 @@ class ExLlamaV2:
|
||||
if isinstance(cache, ExLlamaV2CacheBase):
|
||||
past_len = cache.current_seq_len
|
||||
else:
|
||||
pl = [c.current_seq_len for c in cache]
|
||||
past_len = torch.tensor(pl, dtype = torch.int)
|
||||
past_len = (past_len, past_len)
|
||||
past_len = [c.current_seq_len for c in cache]
|
||||
|
||||
# assert cache is None or isinstance(cache, list) or batch_size <= cache.batch_size
|
||||
|
||||
x = input_ids
|
||||
prev_device = None
|
||||
attn_mask = None
|
||||
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
|
||||
last_state = None
|
||||
|
||||
for idx, module in enumerate(self.modules):
|
||||
|
||||
device = _torch_device(module.device_idx)
|
||||
|
||||
# Build attention mask
|
||||
|
||||
if device != prev_device and device != "cpu":
|
||||
|
||||
prev_device = device
|
||||
attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, input_mask, device)
|
||||
if isinstance(past_len, tuple): past_len = (safe_move_tensor(past_len[0], device), past_len[1])
|
||||
if position_offsets is not None: position_offsets = safe_move_tensor(position_offsets, device)
|
||||
|
||||
# Onward
|
||||
|
||||
if idx == self.head_layer_idx:
|
||||
@@ -662,7 +612,7 @@ class ExLlamaV2:
|
||||
last_state = x.narrow(-2, -1, 1)
|
||||
|
||||
x = safe_move_tensor(x, device)
|
||||
x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras, position_offsets = position_offsets)
|
||||
x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
|
||||
|
||||
if preprocess_only and idx == self.last_kv_layer_idx:
|
||||
x = None
|
||||
|
||||
@@ -166,7 +166,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
self.w3[e].set_device_idx(idx)
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
|
||||
@@ -174,7 +174,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
# for the LoRA matmuls in order to work with the C++ path
|
||||
|
||||
if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8] or (loras is not None and len(loras) > 0):
|
||||
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
|
||||
|
||||
# if loras is None or self.temp_lora_size == 0:
|
||||
# pass_loras = []
|
||||
@@ -183,14 +183,14 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
# pass_loras = [id(x) for x in loras]
|
||||
# pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
|
||||
|
||||
# ref = self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
|
||||
# ref = self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
|
||||
# ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]), pass_loras, pass_lora_temp)
|
||||
ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]))
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(self, hidden_states, cache = None, attn_mask = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward_torch(self, hidden_states, cache = None, attn_params = None, intermediates = False, loras = None, position_offsets = None):
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
|
||||
return 0
|
||||
|
||||
|
||||
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
|
||||
|
||||
output_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
@@ -69,7 +69,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, position_offsets = None):
|
||||
def forward_torch(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, position_offsets = None):
|
||||
|
||||
hidden_states[hidden_states == -float('inf')] = -65504.0
|
||||
hidden_states[hidden_states == float('inf')] = 65504.0
|
||||
|
||||
@@ -13,6 +13,10 @@ from exllamav2.generator import (
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
|
||||
from exllamav2.attn import (
|
||||
ExLlamaV2Attention
|
||||
)
|
||||
|
||||
import argparse, os, math, time
|
||||
import pandas, fastparquet
|
||||
import torch
|
||||
@@ -262,7 +266,8 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
sys.stdout.flush()
|
||||
|
||||
batch_size, seq_len = eval_tokens.shape
|
||||
attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
|
||||
attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
|
||||
# attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
|
||||
|
||||
for idx, module in enumerate(model.modules):
|
||||
module.set_device_idx(-1 if idx == 0 else 0)
|
||||
@@ -283,7 +288,7 @@ if args.eval_dataset or args.standard_perplexity:
|
||||
a = b
|
||||
b = min(b + stream_batch_size, eval_tokens.shape[0])
|
||||
x = hidden_state[a:b, :, :].to("cuda:0")
|
||||
x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0, loras = None, position_offsets = None)
|
||||
x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
|
||||
|
||||
if idx < len(model.modules) - 1:
|
||||
hidden_state[a:b, :, :] = x.to("cpu")
|
||||
|
||||
Reference in New Issue
Block a user