From cc378589a4ad18ea6dff4ae1a8158d19dcd5486e Mon Sep 17 00:00:00 2001
From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com>
Date: Mon, 7 Oct 2024 11:37:06 +0100
Subject: [PATCH] (T5) pad chunks to length of largest chunk (#1990)

When the prompt is chunked using the BREAK keyword, chunks will be padded
to the minimum size of 256 tokens - but chunks can be longer. torch.stack
then fails if all chunks are not the same size, so find the largest and
pad all to match.
#1988 (doesn't quite ID the real issue, prompts longer than 255 tokens
work fine)
---
 backend/text_processing/t5_engine.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py
index dfeead3a..a6b3dc47 100644
--- a/backend/text_processing/t5_engine.py
+++ b/backend/text_processing/t5_engine.py
@@ -117,9 +117,21 @@ class T5TextProcessingEngine:
             else:
                 chunks, token_count = self.tokenize_line(line)
                 line_z_values = []
+
+                # pad all chunks to length of longest chunk
+                max_tokens = 0
+                for chunk in chunks:
+                    max_tokens = max (len(chunk.tokens), max_tokens)
+
                 for chunk in chunks:
                     tokens = chunk.tokens
                     multipliers = chunk.multipliers
+
+                    remaining_count = max_tokens - len(tokens)
+                    if remaining_count > 0:
+                        tokens += [self.id_pad] * remaining_count
+                        multipliers += [1.0] * remaining_count
+
                     z = self.process_tokens([tokens], [multipliers])[0]
                     line_z_values.append(z)
                 cache[line] = line_z_values
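
Illustrative sketch (not part of the commit): torch.stack requires every
tensor in the list to have the same length, so ragged chunks produced by
BREAK must be padded to the longest one before stacking. The pad id and
chunk values below are made-up examples; the real engine pads with
self.id_pad and a multiplier of 1.0, as in the hunk above.

    import torch

    # two prompt chunks of different length, as BREAK can produce
    chunks = [
        [101, 102, 103],
        [101, 102, 103, 104, 105],
    ]
    pad_id = 1  # assumed pad token id, standing in for self.id_pad

    # stacking the chunks as-is would raise
    # "RuntimeError: stack expects each tensor to be equal size"

    # pad every chunk to the length of the longest chunk
    max_tokens = max(len(c) for c in chunks)
    padded = [c + [pad_id] * (max_tokens - len(c)) for c in chunks]

    stacked = torch.stack([torch.tensor(c) for c in padded])
    print(stacked.shape)  # torch.Size([2, 5]) -- stacking now succeeds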