mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
fixed issue with down block residuals when doing slider cfg on sdxl with t2i adapter assisted training
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
@@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
pred_kwargs = {}
|
||||
|
||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||
down_kwargs = copy.deepcopy(pred_kwargs)
|
||||
if 'down_block_additional_residuals' in down_kwargs:
|
||||
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
|
||||
if dbr_batch_size != dn.shape[0]:
|
||||
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
|
||||
down_kwargs['down_block_additional_residuals'] = [
|
||||
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
|
||||
]
|
||||
return self.sd.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
@@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
),
|
||||
timestep=cts,
|
||||
guidance_scale=gs,
|
||||
**pred_kwargs
|
||||
**down_kwargs
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user