mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Initial training script for photomaker training. Needs a little more work.
This commit is contained in:
@@ -14,6 +14,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileIte
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
@@ -145,10 +146,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.noise_scheduler._step_index = None
|
||||
|
||||
denoised_latent = self.sd.noise_scheduler.step(
|
||||
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
|
||||
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
|
||||
)[0]
|
||||
|
||||
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
||||
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
|
||||
self.train_config.dtype))
|
||||
# remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
|
||||
denoised_latent = denoised_latent - residual_noise
|
||||
|
||||
@@ -232,7 +234,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred = noise_pred
|
||||
|
||||
if self.train_config.train_turbo:
|
||||
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
|
||||
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
|
||||
|
||||
ignore_snr = False
|
||||
|
||||
@@ -298,7 +300,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
|
||||
fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
@@ -631,7 +634,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.network.is_active = False
|
||||
can_disable_adapter = False
|
||||
was_adapter_active = False
|
||||
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
|
||||
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
|
||||
isinstance(self.adapter, ReferenceAdapter)
|
||||
):
|
||||
can_disable_adapter = True
|
||||
was_adapter_active = self.adapter.is_active
|
||||
self.adapter.is_active = False
|
||||
@@ -698,6 +703,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
# condition the prompt
|
||||
# todo handle more than one adapter image
|
||||
self.adapter.num_control_images = 1
|
||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
@@ -706,7 +718,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
has_clip_image = batch.clip_image_tensor is not None
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||
raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
raise ValueError(
|
||||
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
@@ -752,7 +765,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if batch.clip_image_tensor is not None:
|
||||
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
|
||||
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
if batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
@@ -879,12 +891,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
mask_multiplier_list,
|
||||
prompt_2_list
|
||||
):
|
||||
if self.train_config.negative_prompt is not None:
|
||||
# add negative prompt
|
||||
conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
||||
range(len(conditioned_prompts))]
|
||||
if prompt_2 is not None:
|
||||
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
# if self.train_config.negative_prompt is not None:
|
||||
# # add negative prompt
|
||||
# conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
||||
# range(len(conditioned_prompts))]
|
||||
# if prompt_2 is not None:
|
||||
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
with network:
|
||||
# encode clip adapter here so embeds are active for tokenizer
|
||||
@@ -977,7 +990,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter_embeds'):
|
||||
image_size = self.adapter.input_size
|
||||
@@ -1029,11 +1041,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = unconditional_clip_embeds.detach()
|
||||
|
||||
|
||||
with self.timer('encode_adapter'):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
|
||||
unconditional_clip_embeds)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass in our scheduler
|
||||
@@ -1060,7 +1072,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||
do_inverted_masked_prior = True
|
||||
|
||||
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
|
||||
if ((
|
||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
@@ -1074,6 +1087,25 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeds=unconditional_embeds
|
||||
)
|
||||
|
||||
# do the custom adapter after the prior prediction
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||
self.adapter.train()
|
||||
conditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=clip_images,
|
||||
prompt_embeds=conditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
)
|
||||
if self.train_config.do_cfg and unconditional_embeds is not None:
|
||||
unconditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=clip_images,
|
||||
prompt_embeds=unconditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
is_unconditional=True
|
||||
)
|
||||
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
||||
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||
|
||||
Reference in New Issue
Block a user