Added built-in plugins and made one for image reference. WIP

Jaret Burkett
2023-08-10 16:20:38 -06:00
parent 1a7e346b41
commit 67dfd9ced0
2 changed files with 156 additions and 30 deletions


@@ -25,6 +25,7 @@ class ReferenceSliderConfig:
         self.resolutions: List[int] = kwargs.get('resolutions', [512])
         self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
         self.target_class: int = kwargs.get('target_class', '')
+        self.additional_losses: List[str] = kwargs.get('additional_losses', [])


 class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
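Note on the new key: ReferenceSliderConfig reads plain kwargs via kwargs.get, so enabling the mirror loss introduced below should only need one more entry in the process config. A hypothetical instantiation (the import path and surrounding plumbing are not shown in this diff; key names are taken from it):

    # Assumes ReferenceSliderConfig is importable from the trainer module.
    config = ReferenceSliderConfig(
        resolutions=[512],
        batch_full_slide=True,
        target_class='person',         # illustrative value; the default shown is ''
        additional_losses=['mirror'],  # makes do_mirror_loss True in hook_train_loop
    )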
@@ -73,6 +74,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         pass

     def hook_train_loop(self, batch):
+        do_mirror_loss = 'mirror' in self.slider_config.additional_losses
+
         with torch.no_grad():
             imgs, prompts = batch
             dtype = get_torch_dtype(self.train_config.dtype)
@@ -135,62 +138,89 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             timesteps = timesteps.long()

             # get noise
-            noise = self.sd.get_latent_noise(
+            noise_positive = self.sd.get_latent_noise(
                 pixel_height=height,
                 pixel_width=width,
                 batch_size=batch_size,
                 noise_offset=self.train_config.noise_offset,
             ).to(self.device_torch, dtype=dtype)

+            if do_mirror_loss:
+                # mirror the noise
+                noise_negative = torch.flip(noise_positive.clone(), dims=[3])
+            else:
+                noise_negative = noise_positive.clone()
+
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
-            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)
+            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
+            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
+
+            noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
+            noise = torch.cat([noise_positive, noise_negative], dim=0)
+            timesteps = torch.cat([timesteps, timesteps], dim=0)
+            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+            unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
+            network_multiplier = [1.0, -1.0]

         flush()

         loss_float = None
+        loss_slide_float = None
+        loss_mirror_float = None

         self.optimizer.zero_grad()
         with self.network:
             assert self.network.is_active
-            loss_list = []
-            for noisy_latents, network_multiplier in zip(
-                    [noisy_positive_latents, noisy_negative_latents],
-                    [1.0, -1.0],
-            ):
-                # do positive first
-                self.network.multiplier = network_multiplier
-                noise_pred = get_noise_pred(
-                    unconditional_embeds,
-                    conditional_embeds,
-                    1,
-                    timesteps,
-                    noisy_latents
-                )
+            # do positive first
+            self.network.multiplier = network_multiplier

-                if self.sd.is_v2:  # check is vpred, don't want to track it down right now
-                    # v-parameterization training
-                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
-                else:
-                    target = noise
+            noise_pred = get_noise_pred(
+                unconditional_embeds,
+                conditional_embeds,
+                1,
+                timesteps,
+                noisy_latents
+            )

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-                loss = loss.mean([1, 2, 3])
+            if self.sd.is_v2:  # check is vpred, don't want to track it down right now
+                # v-parameterization training
+                target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
+            else:
+                target = noise
+            # todo add snr gamma here
+            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+            loss = loss.mean([1, 2, 3])

-                loss = loss.mean()
-                # back propagate loss to free ram
-                loss.backward()
-                loss_list.append(loss.item())
-                # todo add snr gamma here
-                flush()
+            loss = loss.mean()
+            loss_slide_float = loss.item()
+
+            if do_mirror_loss:
+                noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
+                # mirror the negative
+                noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
+                loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
+                loss_mirror = loss_mirror.mean([1, 2, 3])
+                loss_mirror = loss_mirror.mean()
+                loss_mirror_float = loss_mirror.item()
+                loss += loss_mirror
+
+            loss_float = loss.item()
+            # back propagate loss to free ram
+            loss.backward()
+            flush()

         # apply gradients
         optimizer.step()
         lr_scheduler.step()
-        loss_float = sum(loss_list) / len(loss_list)

         # reset network
         self.network.multiplier = 1.0
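Taken on its own, the mirror loss this hunk adds is simple: predict noise for the positive and mirrored-negative halves of the batch, un-flip the negative prediction, and penalize the gap between the two. A minimal standalone sketch (function names here are ours, not the repo's; the tensor ops are lifted from the diff):

    import torch
    import torch.nn.functional as F

    def mirror_noise(noise: torch.Tensor) -> torch.Tensor:
        # Horizontally mirror a BxCxHxW latent (dim 3 is width), matching
        # torch.flip(noise_positive.clone(), dims=[3]) in the diff.
        return torch.flip(noise, dims=[3])

    def mirror_loss(noise_pred: torch.Tensor) -> torch.Tensor:
        # noise_pred holds positive and negative predictions concatenated on
        # the batch axis; un-mirror the negative half and pull them together.
        noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
        noise_pred_neg = torch.flip(noise_pred_neg, dims=[3])
        loss = F.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
        return loss.mean([1, 2, 3]).mean()  # per-sample mean, then batch mean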
@@ -198,5 +228,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         loss_dict = OrderedDict(
             {'loss': loss_float},
         )
+        if do_mirror_loss:
+            loss_dict['l/s'] = loss_slide_float
+            loss_dict['l/m'] = loss_mirror_float
+
         return loss_dict
         # end hook_train_loop
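The large hunk above also swaps the old two-pass loop (one forward/backward per network multiplier) for a single batched pass: positive and negative latents are concatenated on the batch axis, and the network multiplier becomes the per-half list [1.0, -1.0]. A rough, runnable sketch of just the batching step (dummy shapes; the list-multiplier broadcast is our reading of the diff, not a verified API):

    import torch

    batch_size, channels, height, width = 2, 4, 64, 64
    noisy_positive_latents = torch.randn(batch_size, channels, height, width)
    noisy_negative_latents = torch.flip(noisy_positive_latents, dims=[3])

    # One UNet forward now covers both halves of the slider pair.
    noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
    timesteps = torch.randint(0, 1000, (batch_size,))
    timesteps = torch.cat([timesteps, timesteps], dim=0)

    # +1.0 applies the network normally on the positive half; -1.0 inverts it
    # on the negative half.
    network_multiplier = [1.0, -1.0]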