Fixed an issue with down-block residuals when doing slider CFG on SDXL with T2I-adapter-assisted training

This commit is contained in:
Jaret Burkett
2023-10-01 07:32:48 -06:00
parent 085787b799
commit 560251a24f
2 changed files with 13 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import copy
import os
import random
from collections import OrderedDict
@@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
down_kwargs = copy.deepcopy(pred_kwargs)
if 'down_block_additional_residuals' in down_kwargs:
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
if dbr_batch_size != dn.shape[0]:
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
down_kwargs['down_block_additional_residuals'] = [
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
]
return self.sd.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
@@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**pred_kwargs
**down_kwargs
)
with torch.no_grad():