Option to return last hidden state from model.forward()

Author: turboderp
Date: 2023-10-11 23:42:13 +02:00
Parent: 7c740be146
Commit: 366db28094


@@ -319,7 +319,14 @@ class ExLlamaV2:
         return attn_mask


-    def forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None):
+    def forward(self,
+                input_ids,
+                cache = None,
+                input_mask = None,
+                preprocess_only = False,
+                last_id_only = False,
+                loras = None,
+                return_last_state = False):

         q_len = input_ids.shape[-1]
         remaining_q_len = q_len
@@ -342,7 +349,8 @@
                                  input_mask = input_mask,
                                  preprocess_only = preprocess_only,
                                  last_id_only = last_id_only,
-                                 loras = loras)
+                                 loras = loras,
+                                 return_last_state = return_last_state)


         # Confirm that the input fits within the allocated cache space
@@ -352,6 +360,7 @@
         # Split sequence

         result = None
+        last_state = None

         chunk_begin = 0
         while chunk_begin < q_len:
@@ -378,12 +387,13 @@
             _last_id_only = last_id_only
             _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)

-            r = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
-                              cache = cache,
-                              input_mask = input_mask,
-                              preprocess_only = _preprocess_only,
-                              last_id_only = _last_id_only,
-                              loras = loras)
+            r, ls = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
+                                  cache = cache,
+                                  input_mask = input_mask,
+                                  preprocess_only = _preprocess_only,
+                                  last_id_only = _last_id_only,
+                                  loras = loras,
+                                  return_last_state = return_last_state and remaining_q_len <= chunk_size)

             if not _preprocess_only:
                 result = r if result is None else torch.cat((result, r), dim = 1)
@@ -391,11 +401,22 @@
             chunk_begin = chunk_end
             remaining_q_len -= chunk_size
+            last_state = ls

-        return result
+        if last_state is None:
+            return result
+        else:
+            return result, last_state


-    def _forward(self, input_ids, cache = None, input_mask = None, preprocess_only = False, last_id_only = False, loras = None):
+    def _forward(self,
+                 input_ids,
+                 cache = None,
+                 input_mask = None,
+                 preprocess_only = False,
+                 last_id_only = False,
+                 loras = None,
+                 return_last_state = False):

         batch_size, seq_len = input_ids.shape
         past_len = 0
@@ -412,6 +433,7 @@
         x = input_ids
         prev_device = None
         attn_mask = None
+        last_state = None

         for idx, module in enumerate(self.modules):
@@ -427,8 +449,14 @@
             # Onward

-            if last_id_only and idx == self.head_layer_idx:
-                x = x.narrow(-2, -1, 1)
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)

             x = safe_move_tensor(x, device)
             x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras)
@@ -455,4 +483,4 @@
         if head_padding > 0:
             x[:, :, -head_padding:] = -65504.

-        return x
+        return x, last_state
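
For context, a minimal usage sketch of the new option. The loading boilerplate, model directory and variable names below are illustrative only and not part of this commit; what the diff itself establishes is that with return_last_state = True, forward() returns a (logits, last_state) tuple, where last_state is the hidden state passed into the head layer for the last position, while the default False leaves the old single-tensor return unchanged.

# Minimal sketch, assuming the usual ExLlamaV2 loading pattern; the model
# directory and variable names are placeholders, not part of this commit.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"      # placeholder path
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)

input_ids = tokenizer.encode("Once upon a time")

# New behaviour: request the last hidden state along with the logits.
# last_state is the hidden state entering the head layer for the final
# token (narrowed to the last position when last_id_only is set).
logits, last_state = model.forward(input_ids,
                                   cache = cache,
                                   last_id_only = True,
                                   return_last_state = True)

# Old behaviour is unchanged when the flag is left at its default:
# a single logits tensor is returned.
cache.current_seq_len = 0
logits_only = model.forward(input_ids, cache = cache)

Note that in the chunked path, return_last_state is only forwarded to _forward() for the final chunk (remaining_q_len <= chunk_size), so the returned state always corresponds to the end of the full input sequence.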