mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 07:49:24 +00:00
Added support for training lora, dreambooth, and fine tuning. Still need testing and docs
This commit is contained in:
@@ -515,7 +515,8 @@ class StableDiffusion:
|
||||
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
raise ValueError(
|
||||
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
@@ -659,6 +660,39 @@ class StableDiffusion:
|
||||
|
||||
raise ValueError(f"Unknown weight name: {name}")
|
||||
|
||||
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
||||
if trigger is None:
|
||||
return prompt
|
||||
output_prompt = prompt
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
to_replace_list += default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
output_prompt = replace_with + " " + output_prompt
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||
state_dict = OrderedDict()
|
||||
if vae:
|
||||
|
||||
Reference in New Issue
Block a user