(T5) pad chunks to length of largest chunk (#1990)

When the prompt is chunked using the BREAK keyword, each chunk is padded to a minimum length of 256 tokens, but chunks can be longer than that. torch.stack then fails if the chunks are not all the same size, so find the largest chunk and pad every chunk to match.
#1988 (doesn't quite identify the real issue; prompts longer than 255 tokens work fine on their own)
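
For illustration, a minimal PyTorch sketch (not part of the commit; the real fix pads the token lists before encoding rather than the embeddings, and the names here are hypothetical) showing why torch.stack rejects chunks of different lengths and why padding to the longest chunk resolves it:

import torch

emb_dim = 8                              # hypothetical embedding width
chunk_a = torch.zeros(256, emb_dim)      # chunk at the 256-token minimum
chunk_b = torch.zeros(300, emb_dim)      # longer chunk produced by BREAK

try:
    torch.stack([chunk_a, chunk_b])      # RuntimeError: sizes 256 vs 300 differ
except RuntimeError as e:
    print("stack failed:", e)

# Pad every chunk to the length of the longest one, then stacking works.
max_len = max(c.shape[0] for c in (chunk_a, chunk_b))
padded = [torch.nn.functional.pad(c, (0, 0, 0, max_len - c.shape[0])) for c in (chunk_a, chunk_b)]
print(torch.stack(padded).shape)         # torch.Size([2, 300, 8])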
DenOfEquity authored 2024-10-07 11:37:06 +01:00, committed by GitHub
parent 1d3f73b78d
commit cc378589a4


@@ -117,9 +117,21 @@ class T5TextProcessingEngine:
            else:
                chunks, token_count = self.tokenize_line(line)

                line_z_values = []

                # pad all chunks to length of longest chunk
                max_tokens = 0
                for chunk in chunks:
                    max_tokens = max(len(chunk.tokens), max_tokens)

                for chunk in chunks:
                    tokens = chunk.tokens
                    multipliers = chunk.multipliers

                    remaining_count = max_tokens - len(tokens)
                    if remaining_count > 0:
                        tokens += [self.id_pad] * remaining_count
                        multipliers += [1.0] * remaining_count

                    z = self.process_tokens([tokens], [multipliers])[0]
                    line_z_values.append(z)
                cache[line] = line_z_values
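
For illustration only, a self-contained sketch of the same pad-to-longest pattern using mock objects (Chunk, process_tokens and id_pad below are stand-ins, not the real webui classes): after padding, every chunk yields an embedding of the same length, so the downstream torch.stack succeeds.

from dataclasses import dataclass

import torch


@dataclass
class Chunk:                       # stand-in for the webui chunk object
    tokens: list
    multipliers: list


def process_tokens(batch_tokens, batch_multipliers):
    # stub: pretend every token maps to an 8-dim embedding
    return torch.zeros(len(batch_tokens), len(batch_tokens[0]), 8)


id_pad = 0                         # hypothetical pad token id
chunks = [
    Chunk(tokens=[5] * 256, multipliers=[1.0] * 256),   # minimum-size chunk
    Chunk(tokens=[7] * 300, multipliers=[1.0] * 300),   # longer chunk after BREAK
]

# pad all chunks to length of longest chunk (same pattern as the commit)
max_tokens = 0
for chunk in chunks:
    max_tokens = max(len(chunk.tokens), max_tokens)

line_z_values = []
for chunk in chunks:
    tokens = chunk.tokens
    multipliers = chunk.multipliers

    remaining_count = max_tokens - len(tokens)
    if remaining_count > 0:
        tokens += [id_pad] * remaining_count
        multipliers += [1.0] * remaining_count

    z = process_tokens([tokens], [multipliers])[0]
    line_z_values.append(z)

print(torch.stack(line_z_values).shape)    # torch.Size([2, 300, 8]); stack now succeeds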