Mirror of https://github.com/turboderp-org/exllamav2.git (synced 2026-05-11 08:20:29 +00:00)
Option to return last hidden state from model.forward()
@@ -319,7 +319,14 @@ class ExLlamaV2:
         return attn_mask
 
 
-    def forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None):
+    def forward(self,
+                input_ids,
+                cache = None,
+                input_mask = None,
+                preprocess_only = False,
+                last_id_only = False,
+                loras = None,
+                return_last_state = False):
 
         q_len = input_ids.shape[-1]
         remaining_q_len = q_len
@@ -342,7 +349,8 @@ class ExLlamaV2:
                                  input_mask = input_mask,
                                  preprocess_only = preprocess_only,
                                  last_id_only = last_id_only,
-                                 loras = loras)
+                                 loras = loras,
+                                 return_last_state = return_last_state)
 
         # Confirm that the input fits within the allocated cache space
 
@@ -352,6 +360,7 @@ class ExLlamaV2:
         # Split sequence
 
         result = None
+        last_state = None
 
         chunk_begin = 0
         while chunk_begin < q_len:
@@ -378,12 +387,13 @@ class ExLlamaV2:
             _last_id_only = last_id_only
             _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)
 
-            r = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
-                              cache = cache,
-                              input_mask = input_mask,
-                              preprocess_only = _preprocess_only,
-                              last_id_only = _last_id_only,
-                              loras = loras)
+            r, ls = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
+                                  cache = cache,
+                                  input_mask = input_mask,
+                                  preprocess_only = _preprocess_only,
+                                  last_id_only = _last_id_only,
+                                  loras = loras,
+                                  return_last_state = return_last_state and remaining_q_len <= chunk_size)
 
             if not _preprocess_only:
                 result = r if result is None else torch.cat((result, r), dim = 1)
@@ -391,11 +401,22 @@ class ExLlamaV2:
 
             chunk_begin = chunk_end
             remaining_q_len -= chunk_size
+            last_state = ls
 
-        return result
+        if last_state is None:
+            return result
+        else:
+            return result, last_state
 
 
-    def _forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None):
+    def _forward(self,
+                 input_ids,
+                 cache = None,
+                 input_mask = None,
+                 preprocess_only = False,
+                 last_id_only = False,
+                 loras = None,
+                 return_last_state = False):
 
         batch_size, seq_len = input_ids.shape
         past_len = 0
@@ -412,6 +433,7 @@ class ExLlamaV2:
         x = input_ids
         prev_device = None
         attn_mask = None
+        last_state = None
 
         for idx, module in enumerate(self.modules):
 
@@ -427,8 +449,14 @@ class ExLlamaV2:
 
             # Onward
 
-            if last_id_only and idx == self.head_layer_idx:
-                x = x.narrow(-2, -1, 1)
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)
 
             x = safe_move_tensor(x, device)
             x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras)
@@ -455,4 +483,4 @@ class ExLlamaV2:
         if head_padding > 0:
             x[:, :, -head_padding:] = -65504.
 
-        return x
+        return x, last_state
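
For reference, a minimal usage sketch of the new flag. The loading boilerplate below follows the usual exllamav2 example pattern and is not part of this commit; the model directory is a placeholder, and the shape comments reflect what the diff implies rather than documented guarantees.

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"          # placeholder path
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)

input_ids = tokenizer.encode("Hello, world")

# With return_last_state = True, forward() returns (logits, last_state)
# instead of just logits; last_state is the hidden state fed into the
# head layer, narrowed to the final position when last_id_only is set.
logits, last_state = model.forward(input_ids,
                                   cache = cache,
                                   last_id_only = True,
                                   return_last_state = True)

print(logits.shape)        # expected: (batch, 1, vocab_size)
print(last_state.shape)    # expected: (batch, 1, hidden_size)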