From 463cff0d893f0a46fb9ac0138e13f308d30c852c Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 7 Aug 2024 21:10:23 -0700
Subject: [PATCH] fix t5

---
 backend/text_processing/t5_engine.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py
index 12999af2..e6c8c55d 100644
--- a/backend/text_processing/t5_engine.py
+++ b/backend/text_processing/t5_engine.py
@@ -25,7 +25,8 @@ class T5TextProcessingEngine:
         self.emphasis = emphasis.get_current_option(emphasis_name)()
         self.min_length = min_length
 
-        self.id_end = self.tokenizer('')["input_ids"][0]
+        self.id_end = 1
+        self.id_pad = 0
 
         vocab = self.tokenizer.get_vocab()
 
@@ -81,14 +82,16 @@ class T5TextProcessingEngine:
             nonlocal token_count
             nonlocal chunk
 
-            token_count += len(chunk.tokens)
-            to_add = self.min_length - len(chunk.tokens) - 1
-            if to_add > 0:
-                chunk.tokens += [self.id_end] * to_add
-                chunk.multipliers += [1.0] * to_add
-
             chunk.tokens = chunk.tokens + [self.id_end]
             chunk.multipliers = chunk.multipliers + [1.0]
+            current_chunk_length = len(chunk.tokens)
+
+            token_count += current_chunk_length
+            remaining_count = self.min_length - current_chunk_length
+
+            if remaining_count > 0:
+                chunk.tokens += [self.id_pad] * remaining_count
+                chunk.multipliers += [1.0] * remaining_count
 
             chunks.append(chunk)
             chunk = PromptChunk()
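
The patch hard-codes the T5 sentinel ids (EOS = 1, pad = 0) instead of deriving the end id from tokenizing an empty string, pads short chunks with the pad token rather than repeating the end token, and counts the appended EOS toward token_count. Below is a minimal, standalone sketch of the new chunk-finalization behavior, not the project's actual code: the PromptChunk dataclass fields and the finish_chunk helper are hypothetical stand-ins, since the patched t5_engine.py keeps this state on the engine instance and inside a closure.

# Sketch only: illustrates the padding/accounting introduced by this patch under
# the assumption of T5's usual sentinel ids (</s> = 1, <pad> = 0).

from dataclasses import dataclass, field

T5_EOS_ID = 1  # corresponds to self.id_end after the patch
T5_PAD_ID = 0  # corresponds to self.id_pad after the patch

@dataclass
class PromptChunk:
    tokens: list = field(default_factory=list)
    multipliers: list = field(default_factory=list)

def finish_chunk(chunk: PromptChunk, min_length: int) -> int:
    """Append EOS, then pad with the pad token (not EOS) up to min_length.

    Returns the length counted toward token_count, which now includes the
    EOS token but not the trailing padding."""
    chunk.tokens.append(T5_EOS_ID)
    chunk.multipliers.append(1.0)

    current_chunk_length = len(chunk.tokens)

    remaining_count = min_length - current_chunk_length
    if remaining_count > 0:
        chunk.tokens += [T5_PAD_ID] * remaining_count
        chunk.multipliers += [1.0] * remaining_count

    return current_chunk_length

if __name__ == "__main__":
    chunk = PromptChunk(tokens=[71, 1712, 2335], multipliers=[1.0, 1.0, 1.0])
    counted = finish_chunk(chunk, min_length=8)
    print(counted)       # 4 (three prompt tokens + EOS)
    print(chunk.tokens)  # [71, 1712, 2335, 1, 0, 0, 0, 0]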