fixed issue with down block residuals when doing slider cfg on sdxl with t2i adapter assisted training

This commit is contained in:
Jaret Burkett
2023-10-01 07:32:48 -06:00
parent 085787b799
commit 560251a24f
2 changed files with 13 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import copy
import os
import random
from collections import OrderedDict
@@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
down_kwargs = copy.deepcopy(pred_kwargs)
if 'down_block_additional_residuals' in down_kwargs:
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
if dbr_batch_size != dn.shape[0]:
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
down_kwargs['down_block_additional_residuals'] = [
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
]
return self.sd.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
@@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**pred_kwargs
**down_kwargs
)
with torch.no_grad():

View File

@@ -542,8 +542,8 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
# if num_instances > 1:
# print(
# f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt