Refactor to consolidate attn params

This commit is contained in:
turboderp
2024-01-04 04:52:49 +01:00
parent f2e7648d98
commit 41b15dd1c3
11 changed files with 184 additions and 120 deletions

View File

@@ -90,14 +90,14 @@ def test_quant(source: ExLlamaV2Linear,
return variants, variants_bits
def test_error(module, hidden_states, target_states, cache, attn_mask):
def test_error(module, hidden_states, target_states, cache, attn_params):
rfn_sum = 0
rfn_count = 0
for x, xref in zip(hidden_states, target_states):
x = x.cuda()
xref = xref.cuda()
xtest = module.forward(x, cache, attn_mask)
xtest = module.forward(x, cache, attn_params)
xtest = xtest[0].float()
xref = xref[0].float()
rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
@@ -106,7 +106,7 @@ def test_error(module, hidden_states, target_states, cache, attn_mask):
return max(1e-6, 1 - (rfn_sum / rfn_count)).item()
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_mask):
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params):
qjobs, qmaps = get_qparams_reduced(qparams_attn)
results = []
@@ -141,7 +141,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_m
total_bits += bits_o[o]
total_bpw = total_bits / total_numel
accuracy = test_error(module, hidden_states, target_states, cache, attn_mask)
accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")
torch.cuda.empty_cache()
@@ -157,7 +157,7 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_m
return results
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask):
def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params):
qjobs, qmaps = get_qparams_reduced(qparams_mlp)
results = []
@@ -187,7 +187,7 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_ma
total_bits += bits_d[d]
total_bpw = total_bits / total_numel
accuracy = test_error(module, hidden_states, target_states, cache, attn_mask)
accuracy = test_error(module, hidden_states, target_states, cache, attn_params)
print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}")
torch.cuda.empty_cache()
@@ -333,7 +333,7 @@ def measure_quant(job, save_fn, model):
# Reference forward pass
cache = None
attn_mask = model.build_attn_mask(1, hidden_states[0].shape[1], 0, None, "cuda:0") if mode == "self_attn" else None
attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None
target_states = []
if mode == "block_sparse_moe":
@@ -342,7 +342,7 @@ def measure_quant(job, save_fn, model):
for i in range(len(hidden_states)):
x = hidden_states[i].to("cuda:0")
outputs = module.forward(x, cache, attn_mask, intermediates = True)
outputs = module.forward(x, cache, attn_params, intermediates = True)
# Hessians
@@ -379,13 +379,13 @@ def measure_quant(job, save_fn, model):
m = None
if mode == "self_attn":
m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_mask)
m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params)
if mode == "mlp":
m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask)
m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)
if mode == "block_sparse_moe":
m = measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_mask)
m = measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, attn_params)
measurement[module.key + "." + mode] = m

View File

@@ -106,7 +106,7 @@ def quant_linear(job: dict,
source.linear.weight.data = recons_w.T
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
quantizers["q_proj"].prepare()
quantizers["k_proj"].reuse_h(quantizers["q_proj"])
@@ -119,7 +119,7 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, cache, att
quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"])
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
quantizers["gate_proj"].prepare()
quantizers["up_proj"].reuse_h(quantizers["gate_proj"])
@@ -135,7 +135,7 @@ def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn
del quantizers[f"down_proj"]
def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):
num_experts = module.model.config.num_experts
@@ -154,7 +154,7 @@ def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache,
del quantizers[f"w2.{i}"]
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask):
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params):
quantizers["lm_head"].prepare()
@@ -269,7 +269,7 @@ def quant(job, save_fn, model):
# Reference forward pass
cache = None
attn_mask = model.build_attn_mask(1, hidden_states[0].shape[1], 0, None, "cuda:0") if mode == "self_attn" else None
attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None
target_states = []
if mode == "block_sparse_moe":
@@ -278,7 +278,7 @@ def quant(job, save_fn, model):
for i in range(len(hidden_states)):
x = hidden_states[i].to("cuda:0")
outputs = module.forward(x, cache, attn_mask, intermediates = True)
outputs = module.forward(x, cache, attn_params, intermediates = True)
# Hessians
@@ -318,18 +318,18 @@ def quant(job, save_fn, model):
if mode == "self_attn":
strat = strategy[module.key + "." + mode]
quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
if mode == "mlp":
strat = strategy[module.key + "." + mode]
quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
if mode == "block_sparse_moe":
strat = strategy[module.key + "." + mode]
quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)
quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)
if mode == "linear":
quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask)
quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params)
quantizers.clear()
gc.collect()
@@ -352,7 +352,7 @@ def quant(job, save_fn, model):
if mode != "linear":
x = hidden_states[i].to("cuda:0")
output = module.forward(x, cache, attn_mask)
output = module.forward(x, cache, attn_params)
q_states.append(output.to("cpu"))
output = output[0].float()
@@ -365,7 +365,7 @@ def quant(job, save_fn, model):
elif i < job["measurement_rows"]:
x = hidden_states[i].to("cuda:0")
output = module.forward(x, cache, attn_mask)
output = module.forward(x, cache, attn_params)
if module.padding > 0: output = output[:, :, :-module.padding]
logits = output[:, :-1, :]

View File

@@ -18,7 +18,7 @@ import time
# Initialize model and cache
model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/4.0bpw/"
config = ExLlamaV2Config()
config.model_dir = model_directory

View File

@@ -11,6 +11,7 @@ from exllamav2 import ext
from exllamav2.ext import exllamav2_ext as ext_c
# import xformers.ops as xops
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
from exllamav2.compat import safe_move_tensor
# Detect flash-attn
@@ -51,6 +52,104 @@ class ExLlamaV2Attention(ExLlamaV2Module):
temp_lora_size: int = 0
class Params:
    """Attention parameters for one forward pass, with lazily built masks.

    Bundles batch size, sequence length, past length(s), the optional input
    mask and the optional position offsets. Attention masks, the past-lengths
    tensor and the position offsets are created or moved on demand by the
    ``get_*`` accessors and cached on the instance.

    Passing a ``list`` as ``past_len`` selects multi-cache mode: one past
    length (and one mask) per list entry. Any other value is treated as a
    single shared past length.
    """

    def __init__(self, batch_size, seq_len, past_len, input_mask, position_offsets):
        """Store the parameters and reset all cached tensors.

        :param batch_size: Batch size used for the single-cache mask.
        :param seq_len: Number of new positions in this forward pass.
        :param past_len: Single past length, or a list of past lengths.
        :param input_mask: Optional mask combined into the attention mask, or None.
        :param position_offsets: Optional tensor of position offsets, or None.
        """

        self.batch_size = batch_size
        self.seq_len = seq_len

        # A list of past lengths means one cache per batch item
        if isinstance(past_len, list):
            self.past_len = None
            self.past_lens = past_len
            self.multi_cache = True
        else:
            self.past_len = past_len
            self.past_lens = None
            self.multi_cache = False

        self.input_mask = input_mask
        self.position_offsets = position_offsets

        # Lazily populated caches
        self.attn_mask = None
        self.attn_masks = None
        self.past_lens_tensor = None

    def is_causal(self):
        """Return True when no input mask was supplied."""

        return self.input_mask is None

    def get_position_offsets(self, device):
        """Return the position offsets, moving them to ``device`` if needed.

        The moved tensor replaces the stored one. Asserts that position
        offsets were supplied.
        """

        assert self.position_offsets is not None
        if self.position_offsets.device != device:
            self.position_offsets = safe_move_tensor(self.position_offsets, device)
        return self.position_offsets

    def get_past_lens(self, device):
        """Return the past lengths as an int tensor on ``device``.

        The tensor is built on first use and moved (and re-cached) when a
        different device is requested. Asserts multi-cache past lengths exist.
        """

        assert self.past_lens is not None
        if self.past_lens_tensor is None:
            self.past_lens_tensor = torch.tensor(self.past_lens, dtype = torch.int, device = device)
        elif self.past_lens_tensor.device != device:
            self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, device)
        return self.past_lens_tensor

    def get_attn_mask(self, device):
        """Return the single-cache attention mask on ``device``, or None.

        None is returned when ``build_attn_mask`` decides no mask is needed.
        """

        if self.attn_mask is None:
            self.attn_mask = self.build_attn_mask(device)
        elif self.attn_mask.device != device:
            self.attn_mask = safe_move_tensor(self.attn_mask, device)
        return self.attn_mask

    def get_attn_masks(self, device):
        """Return the list of per-cache attention masks on ``device``.

        Entries are None where no mask is needed. Cached masks living on
        another device are moved and the cached list is replaced.
        """

        if self.attn_masks is None:
            self.attn_masks = self.build_attn_masks(device)
        elif any(m is not None and m.device != device for m in self.attn_masks):
            # Scan with any() instead of indexing [0] so an empty list is safe
            self.attn_masks = [(safe_move_tensor(m, device) if m is not None else None) for m in self.attn_masks]
        return self.attn_masks

    def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_mask):
        """Build one additive causal mask, optionally merged with ``input_mask``.

        The result is a float16 tensor of shape
        ``(batch_size, 1, seq_len, past_len + seq_len)`` holding 0 where
        attention is allowed and -inf above the causal diagonal.
        """

        attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)

        # -inf on and above the diagonal of the (seq_len - 1)-sized square,
        # written one column to the right of the past so that query row r
        # cannot see key columns beyond past_len + r
        attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
        attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu

        if input_mask is not None:
            # NOTE(review): input_mask is indexed as [:, :width], so it is
            # presumably (batch, width) with additive values - confirm against callers
            min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
            input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
            input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
            attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)

        return attn_mask

    def build_attn_mask(self, device):
        """Build the mask for single-cache mode.

        Returns None when there is no input mask and only one new position,
        since nothing needs masking in that case.
        """

        assert not self.multi_cache, "Building single mask for multiple caches"
        if self.input_mask is None and self.seq_len == 1: return None
        return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask)

    def build_attn_masks(self, device):
        """Build one mask per cache for multi-cache mode.

        Returns a list with one entry per past length. All entries are None
        when there is no input mask and only one new position.
        """

        assert self.multi_cache, "Building multiple masks for single cache"

        if self.input_mask is None and self.seq_len == 1:
            return [None] * len(self.past_lens)

        attn_masks = []
        for i, past_len in enumerate(self.past_lens):
            # Only index the input mask when one was supplied; without this
            # guard a maskless multi-token pass fails on None[i]
            input_mask = self.input_mask[i] if self.input_mask is not None else None
            attn_masks.append(self.build_single_attn_mask(1, self.seq_len, past_len, device, input_mask))
        return attn_masks
def __init__(self, model, key, layer_idx):
super().__init__(model, key)
@@ -225,11 +324,11 @@ class ExLlamaV2Attention(ExLlamaV2Module):
return hidden_states
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
global has_flash_attn
if self.q_handle is None or intermediates:
return self.forward_torch(hidden_states, cache, attn_mask, past_len, intermediates, loras = loras, position_offsets = position_offsets)
return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras)
batch_size = hidden_states.shape[0]
q_len = hidden_states.shape[1]
@@ -278,12 +377,12 @@ class ExLlamaV2Attention(ExLlamaV2Module):
pass_loras = [id(x) for x in loras]
pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
if isinstance(past_len, tuple):
if attn_params.multi_cache:
pass_past_len_1 = -1
pass_past_len_2 = past_len[0]
elif position_offsets is not None:
pass_past_len_2 = attn_params.get_past_lens(hidden_states.device)
elif attn_params.position_offsets is not None:
pass_past_len_1 = past_len
pass_past_len_2 = position_offsets
pass_past_len_2 = attn_params.get_position_offsets(hidden_states.device)
else:
pass_past_len_1 = past_len
pass_past_len_2 = ext.none_tensor
@@ -336,7 +435,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# Torch matmul attention
if self.model.config.no_flash_attn or not has_flash_attn:
if self.model.config.no_flash_attn or not has_flash_attn or not attn_params.is_causal():
q_states = q_states.transpose(1, 2)
k_states = k_states.transpose(1, 2)
@@ -350,6 +449,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
q_states = None
attn_weights /= math.sqrt(head_dim)
attn_mask = attn_params.get_attn_mask(hidden_states.device)
if attn_mask is not None: attn_weights = attn_weights + attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
@@ -364,6 +464,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
else:
# TODO: Enable flash-attn with input mask
attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
@@ -394,6 +495,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
else:
assert attn_params.multi_cache
attn_masks = attn_params.get_attn_masks(hidden_states.device)
attn_outputs = []
for i in range(len(cache)):
@@ -401,20 +505,20 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# Add keys and values to cache
batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, 1, 0, past_len[1][i].item())
new_keys = batch_keys.narrow(1, past_len[1][i], q_len)
new_values = batch_values.narrow(1, past_len[1][i], q_len)
batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, 1, 0, past_len[i])
new_keys = batch_keys.narrow(1, past_len[i], q_len)
new_values = batch_values.narrow(1, past_len[i], q_len)
new_keys.copy_(k_states.narrow(0, i, 1))
new_values.copy_(v_states.narrow(0, i, 1))
# Store updated cache values
cache[i].store_kv_state(self.layer_idx, 1, past_len[1][i].item(), q_len)
cache[i].store_kv_state(self.layer_idx, 1, past_len[i], q_len)
# Key/value tensors with past
k_states_b = batch_keys.narrow(1, 0, past_len[1][i] + q_len)
v_states_b = batch_values.narrow(1, 0, past_len[1][i] + q_len)
k_states_b = batch_keys.narrow(1, 0, past_len[i] + q_len)
v_states_b = batch_values.narrow(1, 0, past_len[i] + q_len)
# Torch matmul attention
@@ -432,7 +536,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
k_states_b = None
attn_weights /= math.sqrt(head_dim)
if attn_mask is not None: attn_weights = attn_weights + attn_mask[i]
if attn_masks[i] is not None: attn_weights = attn_weights + attn_masks[i]
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
v_states_b = self.repeat_kv(v_states_b, num_key_value_groups)
@@ -465,7 +569,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
return hidden_states
def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward_torch(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
num_attention_heads = self.model.config.num_attention_heads
num_key_value_heads = self.model.config.num_key_value_heads
@@ -506,9 +610,13 @@ class ExLlamaV2Attention(ExLlamaV2Module):
constants = self.model.get_device_tensors(self.device_idx, scratch = False)
offset_tensor = position_offsets if position_offsets is not None else ext.none_tensor
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, offset_tensor)
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, offset_tensor)
if attn_params.position_offsets is not None:
position_offsets = attn_params.get_position_offsets(hidden_states.device)
else:
position_offsets = ext.none_tensor
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets)
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets)
# Add keys and values to cache
@@ -527,7 +635,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# Torch matmul attention
if self.model.config.no_flash_attn or not has_flash_attn:
if self.model.config.no_flash_attn or not has_flash_attn or not attn_params.is_causal():
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -538,6 +646,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
attn_weights = torch.matmul(query_states, key_states)
attn_weights /= math.sqrt(head_dim)
attn_mask = attn_params.get_attn_mask(hidden_states.device)
if attn_mask is not None: attn_weights = attn_weights + attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)

View File

@@ -67,7 +67,7 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
return 0
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
hidden_states = self.embedding.forward(hidden_states)

View File

@@ -110,7 +110,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
return self.out_features * self.model.config.max_input_len * self.model.config.max_batch_size * 4 + 128
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, force_recons = False, force_cuda = False, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, force_recons = False, force_cuda = False):
# Linear forward

View File

@@ -158,14 +158,14 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.down_proj.set_device_idx(idx)
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None):
# global catch_key
#
# if self.key == catch_key:
# return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
# return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
if self.q_handle is None or intermediates:
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
if loras is None or self.temp_lora_size == 0:
pass_loras = []
@@ -182,7 +182,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
return hidden_states
def forward_torch(self, hidden_states, cache = None, attn_mask = None, intermediates = False, loras = None, position_offsets = None):
def forward_torch(self, hidden_states, cache = None, attn_params = None, intermediates = False, loras = None, position_offsets = None):
residual = hidden_states
post_norm = self.post_attention_layernorm.forward(hidden_states)

View File

@@ -172,7 +172,7 @@ class ExLlamaV2:
# TODO: Option to reserve space for cache while loading model
state_size = self.config.hidden_size * self.config.max_input_len * self.config.max_batch_size * 2
mask_size = self.config.max_input_len ** 2 * 2
mask_size = self.config.max_input_len ** 2 * self.config.max_batch_size * 2
# Bytes remaining per device
@@ -304,7 +304,7 @@ class ExLlamaV2:
hidden_state = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
batch_size, seq_len = hidden_state.shape
past_len = 0
attn_mask = None
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, None, None)
# Size of fixed scratch space
@@ -331,16 +331,16 @@ class ExLlamaV2:
while True:
# If we've reached a new device, allocate fixed tensors and attention mask
# If we've reached a new device, allocate fixed tensors
if current_device > last_touched_device:
self.device_tensors.append(ExLlamaV2DeviceTensors(self, current_device, scratch_fixed))
if attn_mask is not None:
reserved_vram_tensors.append(attn_mask)
attn_mask = safe_move_tensor(attn_mask, _torch_device(current_device))
else:
attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, None, _torch_device(current_device))
# if attn_mask is not None:
# reserved_vram_tensors.append(attn_mask)
# attn_mask = safe_move_tensor(attn_mask, _torch_device(current_device))
# else:
# attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, None, _torch_device(current_device))
b = reserve_vram[current_device]
reserved_vram_tensors.append(torch.empty((b,), dtype = torch.int8, device = _torch_device(current_device)))
@@ -366,7 +366,7 @@ class ExLlamaV2:
hidden_state = hidden_state.narrow(-2, -1, 1)
hidden_state = safe_move_tensor(hidden_state, _torch_device(current_device))
hidden_state = module.forward(hidden_state, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras)
hidden_state = module.forward(hidden_state, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
fail = False
except Exception as e:
@@ -403,7 +403,7 @@ class ExLlamaV2:
if callback_gen is not None: yield from callback_gen(len(self.modules), len(self.modules))
hidden_state = None
attn_mask = None
attn_params = None
reserved_vram_tensors = None
gc.collect()
@@ -472,44 +472,6 @@ class ExLlamaV2:
return False
def build_attn_mask(self, batch_size, seq_len, past_len, input_mask, device):
if input_mask is None and seq_len == 1: return None
if isinstance(past_len, tuple):
attn_masks = []
for i in range(len(past_len[1])):
attn_mask = torch.zeros((1, 1, seq_len, past_len[1][i] + seq_len), dtype = torch.float16, device = device)
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
attn_mask[:, :, : seq_len - 1, past_len[1][i] + 1: past_len[1][i] + seq_len] = attn_mask_triu
if input_mask is not None:
min_mask_width = min(input_mask[i].shape[-1], seq_len + past_len[1][i])
input_mask_part = safe_move_tensor(input_mask[i][:, :min_mask_width], attn_mask.device)
input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)
attn_masks.append(attn_mask)
return attn_masks
else:
attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), float("-inf")))
attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu
if input_mask is not None:
min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)
return attn_mask
@torch.inference_mode()
def forward(self,
input_ids,
@@ -626,30 +588,18 @@ class ExLlamaV2:
if isinstance(cache, ExLlamaV2CacheBase):
past_len = cache.current_seq_len
else:
pl = [c.current_seq_len for c in cache]
past_len = torch.tensor(pl, dtype = torch.int)
past_len = (past_len, past_len)
past_len = [c.current_seq_len for c in cache]
# assert cache is None or isinstance(cache, list) or batch_size <= cache.batch_size
x = input_ids
prev_device = None
attn_mask = None
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
last_state = None
for idx, module in enumerate(self.modules):
device = _torch_device(module.device_idx)
# Build attention mask
if device != prev_device and device != "cpu":
prev_device = device
attn_mask = self.build_attn_mask(batch_size, seq_len, past_len, input_mask, device)
if isinstance(past_len, tuple): past_len = (safe_move_tensor(past_len[0], device), past_len[1])
if position_offsets is not None: position_offsets = safe_move_tensor(position_offsets, device)
# Onward
if idx == self.head_layer_idx:
@@ -662,7 +612,7 @@ class ExLlamaV2:
last_state = x.narrow(-2, -1, 1)
x = safe_move_tensor(x, device)
x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras, position_offsets = position_offsets)
x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
if preprocess_only and idx == self.last_kv_layer_idx:
x = None

View File

@@ -166,7 +166,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
self.w3[e].set_device_idx(idx)
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -174,7 +174,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
# for the LoRA matmuls in order to work with the C++ path
if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8] or (loras is not None and len(loras) > 0):
return self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
return self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
# if loras is None or self.temp_lora_size == 0:
# pass_loras = []
@@ -183,14 +183,14 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
# pass_loras = [id(x) for x in loras]
# pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
# ref = self.forward_torch(hidden_states, cache, attn_mask, intermediates, loras = loras)
# ref = self.forward_torch(hidden_states, cache, attn_params, intermediates, loras = loras)
# ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]), pass_loras, pass_lora_temp)
ext_c.q_moe_mlp_forward_(self.q_handle, hidden_states.view(-1, hidden_states.shape[-1]))
return hidden_states
def forward_torch(self, hidden_states, cache = None, attn_mask = None, intermediates = False, loras = None, position_offsets = None):
def forward_torch(self, hidden_states, cache = None, attn_params = None, intermediates = False, loras = None, position_offsets = None):
residual = hidden_states

View File

@@ -55,7 +55,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
return 0
def forward(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
def forward(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, loras = None, position_offsets = None):
output_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -69,7 +69,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
return hidden_states
def forward_torch(self, hidden_states, cache = None, attn_mask = None, past_len = None, intermediates = False, position_offsets = None):
def forward_torch(self, hidden_states, cache = None, attn_params = None, past_len = None, intermediates = False, position_offsets = None):
hidden_states[hidden_states == -float('inf')] = -65504.0
hidden_states[hidden_states == float('inf')] = 65504.0

View File

@@ -13,6 +13,10 @@ from exllamav2.generator import (
ExLlamaV2Sampler
)
from exllamav2.attn import (
ExLlamaV2Attention
)
import argparse, os, math, time
import pandas, fastparquet
import torch
@@ -262,7 +266,8 @@ if args.eval_dataset or args.standard_perplexity:
sys.stdout.flush()
batch_size, seq_len = eval_tokens.shape
attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
# attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
for idx, module in enumerate(model.modules):
module.set_device_idx(-1 if idx == 0 else 0)
@@ -283,7 +288,7 @@ if args.eval_dataset or args.standard_perplexity:
a = b
b = min(b + stream_batch_size, eval_tokens.shape[0])
x = hidden_state[a:b, :, :].to("cuda:0")
x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0, loras = None, position_offsets = None)
x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
if idx < len(model.modules) - 1:
hidden_state[a:b, :, :] = x.to("cpu")