Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-07 05:59:57 +00:00
Added built-in plugins and made one for image reference. WIP
@@ -25,6 +25,7 @@ class ReferenceSliderConfig:
         self.resolutions: List[int] = kwargs.get('resolutions', [512])
         self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
         self.target_class: str = kwargs.get('target_class', '')
+        self.additional_losses: List[str] = kwargs.get('additional_losses', [])


 class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
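A rough usage sketch of the new option (the constructor call and the example values here are assumptions, not part of the commit); 'mirror' is the only value this diff checks for in additional_losses:

# hypothetical config; only 'additional_losses' is new in this commit
config_kwargs = {
    'resolutions': [512],
    'batch_full_slide': True,
    'target_class': 'person',         # illustrative value
    'additional_losses': ['mirror'],  # enables the mirror loss branch in hook_train_loop
}
slider_config = ReferenceSliderConfig(**config_kwargs)
assert 'mirror' in slider_config.additional_losses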
@@ -73,6 +74,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         pass

     def hook_train_loop(self, batch):
+        do_mirror_loss = 'mirror' in self.slider_config.additional_losses
+
         with torch.no_grad():
             imgs, prompts = batch
             dtype = get_torch_dtype(self.train_config.dtype)
@@ -135,62 +138,89 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             timesteps = timesteps.long()

             # get noise
-            noise = self.sd.get_latent_noise(
+            noise_positive = self.sd.get_latent_noise(
                 pixel_height=height,
                 pixel_width=width,
                 batch_size=batch_size,
                 noise_offset=self.train_config.noise_offset,
             ).to(self.device_torch, dtype=dtype)

+            if do_mirror_loss:
+                # mirror the noise
+                noise_negative = torch.flip(noise_positive.clone(), dims=[3])
+            else:
+                noise_negative = noise_positive.clone()
+
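For intuition on the dims=[3] flip: the latents are laid out NCHW, so flipping dim 3 reverses the width axis, i.e. a horizontal mirror. A small standalone check (shapes are illustrative):

import torch

latent = torch.randn(1, 4, 64, 64)          # (batch, channels, height, width)
mirrored = torch.flip(latent, dims=[3])     # reverse the width axis: a horizontal mirror
assert torch.equal(mirrored[..., 0], latent[..., -1])       # first column is the old last column
assert torch.equal(torch.flip(mirrored, dims=[3]), latent)  # flipping twice is the identity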
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
-            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)
+            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
+            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
+
+            noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
+            noise = torch.cat([noise_positive, noise_negative], dim=0)
+            timesteps = torch.cat([timesteps, timesteps], dim=0)
+            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+            unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
+            network_multiplier = [1.0, -1.0]
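The torch.cat calls above stack the positive and negative halves into one doubled batch so a single forward pass covers both, and network_multiplier = [1.0, -1.0] presumably applies the slider network with opposite sign to each half. A shape-only sketch (sizes are illustrative, not from the commit):

import torch

noisy_positive_latents = torch.randn(2, 4, 64, 64)
noisy_negative_latents = torch.randn(2, 4, 64, 64)
timesteps = torch.randint(0, 1000, (2,))

noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)  # (4, 4, 64, 64)
timesteps = torch.cat([timesteps, timesteps], dim=0)                                # (4,)
positive_half, negative_half = torch.chunk(noisy_latents, 2, dim=0)                 # split back later
assert torch.equal(positive_half, noisy_positive_latents)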

             flush()

             loss_float = None
+            loss_slide_float = None
+            loss_mirror_float = None

             self.optimizer.zero_grad()
             with self.network:
                 assert self.network.is_active
-                loss_list = []
-                for noisy_latents, network_multiplier in zip(
-                        [noisy_positive_latents, noisy_negative_latents],
-                        [1.0, -1.0],
-                ):
-                    # do positive first
-                    self.network.multiplier = network_multiplier
-
-                    noise_pred = get_noise_pred(
-                        unconditional_embeds,
-                        conditional_embeds,
-                        1,
-                        timesteps,
-                        noisy_latents
-                    )
-
-                    if self.sd.is_v2:  # check is vpred, don't want to track it down right now
-                        # v-parameterization training
-                        target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
-                    else:
-                        target = noise
-
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-                    loss = loss.mean([1, 2, 3])
-
-                    loss = loss.mean()
-                    # back propagate loss to free ram
-                    loss.backward()
-                    loss_list.append(loss.item())
-                    # todo add snr gamma here
-
-                    flush()
+                # do positive first
+                self.network.multiplier = network_multiplier
+
+                noise_pred = get_noise_pred(
+                    unconditional_embeds,
+                    conditional_embeds,
+                    1,
+                    timesteps,
+                    noisy_latents
+                )
+
+                if self.sd.is_v2:  # check is vpred, don't want to track it down right now
+                    # v-parameterization training
+                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
+                else:
+                    target = noise
+
+                # todo add snr gamma here
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = loss.mean([1, 2, 3])
+
+                loss = loss.mean()
+                loss_slide_float = loss.item()
+
+                if do_mirror_loss:
+                    noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
+                    # mirror the negative
+                    noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
+                    loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
+                    loss_mirror = loss_mirror.mean([1, 2, 3])
+                    loss_mirror = loss_mirror.mean()
+                    loss_mirror_float = loss_mirror.item()
+                    loss += loss_mirror
+
+                loss_float = loss.item()
+
+                # back propagate loss to free ram
+                loss.backward()
+
+                flush()

             # apply gradients
             optimizer.step()
             lr_scheduler.step()

-            loss_float = sum(loss_list) / len(loss_list)
-
             # reset network
             self.network.multiplier = 1.0
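A standalone sketch of the mirror loss added above, using plain torch ops (shapes are illustrative; in the trainer the two halves come from the doubled batch): the negative half's prediction is flipped back so that a mirrored pair is compared in the same orientation, and the scalar result is added to the slide loss.

import torch
import torch.nn.functional as F

# pretend this came from a doubled batch: positive half first, negative half second
noise_pred = torch.randn(4, 4, 64, 64)

noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
noise_pred_neg = torch.flip(noise_pred_neg, dims=[3])       # undo the horizontal mirror

loss_mirror = F.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
loss_mirror = loss_mirror.mean([1, 2, 3])                   # per-example mean over C, H, W
loss_mirror = loss_mirror.mean()                            # scalar mirror consistency loss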
@@ -198,5 +228,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         loss_dict = OrderedDict(
             {'loss': loss_float},
         )
+
+        if do_mirror_loss:
+            loss_dict['l/s'] = loss_slide_float
+            loss_dict['l/m'] = loss_mirror_float
         return loss_dict
     # end hook_train_loop
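For reference, a minimal sketch of the returned dict with the mirror loss enabled; the key names come from the hunk above, the numbers are illustrative. Note that 'loss' is the combined slide plus mirror value, while 'l/s' and 'l/m' report the two components separately.

from collections import OrderedDict

loss_dict = OrderedDict({'loss': 0.123})  # always present
do_mirror_loss = True
if do_mirror_loss:
    loss_dict['l/s'] = 0.101              # slide (noise prediction) loss
    loss_dict['l/m'] = 0.022              # mirror consistency loss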