diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py index dfeead3a..a6b3dc47 100644 --- a/backend/text_processing/t5_engine.py +++ b/backend/text_processing/t5_engine.py @@ -117,9 +117,21 @@ class T5TextProcessingEngine: else: chunks, token_count = self.tokenize_line(line) line_z_values = [] + + # pad all chunks to length of longest chunk + max_tokens = 0 + for chunk in chunks: + max_tokens = max (len(chunk.tokens), max_tokens) + for chunk in chunks: tokens = chunk.tokens multipliers = chunk.multipliers + + remaining_count = max_tokens - len(tokens) + if remaining_count > 0: + tokens += [self.id_pad] * remaining_count + multipliers += [1.0] * remaining_count + z = self.process_tokens([tokens], [multipliers])[0] line_z_values.append(z) cache[line] = line_z_values