From 560251a24f8cc6307f93d62ca86fdb12c3982dd7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 1 Oct 2023 07:32:48 -0600 Subject: [PATCH] fixed issue with down block residuals when doing slider cfg on sdxl with t2i adapter assisted training --- jobs/process/TrainSliderProcess.py | 11 ++++++++++- toolkit/prompt_utils.py | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 92f60925..2288efe5 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -1,3 +1,4 @@ +import copy import os import random from collections import OrderedDict @@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess): pred_kwargs = {} def get_noise_pred(neg, pos, gs, cts, dn): + down_kwargs = copy.deepcopy(pred_kwargs) + if 'down_block_additional_residuals' in down_kwargs: + dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0] + if dbr_batch_size != dn.shape[0]: + amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size) + down_kwargs['down_block_additional_residuals'] = [ + torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals'] + ] return self.sd.predict_noise( latents=dn, text_embeddings=train_tools.concat_prompt_embeddings( @@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess): ), timestep=cts, guidance_scale=gs, - **pred_kwargs + **down_kwargs ) with torch.no_grad(): diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index e56cb793..6a5032f3 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -542,8 +542,8 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i # add it to the beginning of the prompt output_prompt = replace_with + " " + output_prompt - if num_instances > 1: - print( - f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + # if num_instances > 1: + # print( + # f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") return output_prompt