Initial training script for photomaker training. Needs a little more work.

This commit is contained in:
Jaret Burkett
2024-01-15 18:46:26 -07:00
parent 5276975fb0
commit eebd3c8212
8 changed files with 1183 additions and 24 deletions

View File

@@ -14,6 +14,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileIte
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -145,10 +146,11 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.noise_scheduler._step_index = None
denoised_latent = self.sd.noise_scheduler.step(
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
)[0]
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
self.train_config.dtype))
# remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
denoised_latent = denoised_latent - residual_noise
@@ -232,7 +234,7 @@ class SDTrainer(BaseSDTrainProcess):
pred = noise_pred
if self.train_config.train_turbo:
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
ignore_snr = False
@@ -298,7 +300,8 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
@@ -631,7 +634,9 @@ class SDTrainer(BaseSDTrainProcess):
self.network.is_active = False
can_disable_adapter = False
was_adapter_active = False
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
isinstance(self.adapter, ReferenceAdapter)
):
can_disable_adapter = True
was_adapter_active = self.adapter.is_active
self.adapter.is_active = False
@@ -698,6 +703,13 @@ class SDTrainer(BaseSDTrainProcess):
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
self.adapter.num_control_images = 1
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
@@ -706,7 +718,8 @@ class SDTrainer(BaseSDTrainProcess):
has_clip_image = batch.clip_image_tensor is not None
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
@@ -752,7 +765,6 @@ class SDTrainer(BaseSDTrainProcess):
if batch.clip_image_tensor is not None:
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
@@ -879,12 +891,13 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier_list,
prompt_2_list
):
if self.train_config.negative_prompt is not None:
# add negative prompt
conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
range(len(conditioned_prompts))]
if prompt_2 is not None:
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
# if self.train_config.negative_prompt is not None:
# # add negative prompt
# conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
# range(len(conditioned_prompts))]
# if prompt_2 is not None:
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
# encode clip adapter here so embeds are active for tokenizer
@@ -977,7 +990,6 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
image_size = self.adapter.input_size
@@ -1029,11 +1041,11 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
unconditional_clip_embeds = unconditional_clip_embeds.detach()
with self.timer('encode_adapter'):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
unconditional_clip_embeds)
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our scheduler
@@ -1060,7 +1072,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
do_inverted_masked_prior = True
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
if ((
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
@@ -1074,6 +1087,25 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeds=unconditional_embeds
)
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True
)
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the giadance later
if batch.unconditional_latents is not None or self.do_guided_loss: