From 366db28094ee40719bc8292bfcc397895cd38a89 Mon Sep 17 00:00:00 2001 From: turboderp Date: Wed, 11 Oct 2023 23:42:13 +0200 Subject: [PATCH] Option to return last hidden state from model.forward() --- exllamav2/model.py | 54 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/exllamav2/model.py b/exllamav2/model.py index fc66521..e397224 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -319,7 +319,14 @@ class ExLlamaV2: return attn_mask - def forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None): + def forward(self, + input_ids, + cache = None, + input_mask = None, + preprocess_only = False, + last_id_only = False, + loras = None, + return_last_state = False): q_len = input_ids.shape[-1] remaining_q_len = q_len @@ -342,7 +349,8 @@ class ExLlamaV2: input_mask = input_mask, preprocess_only = preprocess_only, last_id_only = last_id_only, - loras = loras) + loras = loras, + return_last_state = return_last_state) # Confirm that the input fits within the allocated cache space @@ -352,6 +360,7 @@ class ExLlamaV2: # Split sequence result = None + last_state = None chunk_begin = 0 while chunk_begin < q_len: @@ -378,12 +387,13 @@ class ExLlamaV2: _last_id_only = last_id_only _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only) - r = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end], - cache = cache, - input_mask = input_mask, - preprocess_only = _preprocess_only, - last_id_only = _last_id_only, - loras = loras) + r, ls = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end], + cache = cache, + input_mask = input_mask, + preprocess_only = _preprocess_only, + last_id_only = _last_id_only, + loras = loras, + return_last_state = return_last_state and remaining_q_len <= chunk_size) if not _preprocess_only: result = r if result is None else torch.cat((result, r), dim = 1) @@ -391,11 +401,22 @@ class ExLlamaV2: chunk_begin = chunk_end remaining_q_len -= chunk_size + last_state = ls - return result + if last_state is None: + return result + else: + return result, last_state - def _forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None): + def _forward(self, + input_ids, + cache = None, + input_mask = None, + preprocess_only = False, + last_id_only = False, + loras = None, + return_last_state = False): batch_size, seq_len = input_ids.shape past_len = 0 @@ -412,6 +433,7 @@ class ExLlamaV2: x = input_ids prev_device = None attn_mask = None + last_state = None for idx, module in enumerate(self.modules): @@ -427,8 +449,14 @@ class ExLlamaV2: # Onward - if last_id_only and idx == self.head_layer_idx: - x = x.narrow(-2, -1, 1) + if idx == self.head_layer_idx: + if last_id_only and return_last_state: + x = x.narrow(-2, -1, 1) + last_state = x + elif last_id_only: + x = x.narrow(-2, -1, 1) + elif return_last_state: + last_state = x.narrow(-2, -1, 1) x = safe_move_tensor(x, device) x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras) @@ -455,4 +483,4 @@ class ExLlamaV2: if head_padding > 0: x[:, :, -head_padding:] = -65504. - return x \ No newline at end of file + return x, last_state \ No newline at end of file