mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
fixed issue with down block residuals when doing slider cfg on sdxl with t2i adapter assisted training
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@@ -302,6 +303,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
pred_kwargs = {}
|
pred_kwargs = {}
|
||||||
|
|
||||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||||
|
down_kwargs = copy.deepcopy(pred_kwargs)
|
||||||
|
if 'down_block_additional_residuals' in down_kwargs:
|
||||||
|
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
|
||||||
|
if dbr_batch_size != dn.shape[0]:
|
||||||
|
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
|
||||||
|
down_kwargs['down_block_additional_residuals'] = [
|
||||||
|
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
|
||||||
|
]
|
||||||
return self.sd.predict_noise(
|
return self.sd.predict_noise(
|
||||||
latents=dn,
|
latents=dn,
|
||||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||||
@@ -311,7 +320,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
),
|
),
|
||||||
timestep=cts,
|
timestep=cts,
|
||||||
guidance_scale=gs,
|
guidance_scale=gs,
|
||||||
**pred_kwargs
|
**down_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -542,8 +542,8 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i
|
|||||||
# add it to the beginning of the prompt
|
# add it to the beginning of the prompt
|
||||||
output_prompt = replace_with + " " + output_prompt
|
output_prompt = replace_with + " " + output_prompt
|
||||||
|
|
||||||
if num_instances > 1:
|
# if num_instances > 1:
|
||||||
print(
|
# print(
|
||||||
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
# f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||||
|
|
||||||
return output_prompt
|
return output_prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user