fixed issue with down block residuals when doing slider cfg on sdxl with t2i adapter assisted training

This commit is contained in:
Jaret Burkett
2023-10-01 07:32:48 -06:00
parent 085787b799
commit 560251a24f
2 changed files with 13 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import copy
import os
import random
from collections import OrderedDict
@@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
down_kwargs = copy.deepcopy(pred_kwargs)
if 'down_block_additional_residuals' in down_kwargs:
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
if dbr_batch_size != dn.shape[0]:
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
down_kwargs['down_block_additional_residuals'] = [
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
]
return self.sd.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
@@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**down_kwargs
)
with torch.no_grad():

View File

@@ -542,8 +542,8 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
# if num_instances > 1:
#     print(
#         f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt