Initial commit

This commit is contained in:
Jaret Burkett
2023-12-29 13:07:35 -07:00
parent 0892dec4a5
commit bafacf3b65
5 changed files with 175 additions and 37 deletions

View File

@@ -1,6 +1,9 @@
import random
from collections import OrderedDict
from typing import Union, Literal, List
from diffusers import T2IAdapter
from typing import Union, Literal, List, Optional
import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
import torch.functional as F
from toolkit import train_tools
@@ -39,6 +42,7 @@ class SDTrainer(BaseSDTrainProcess):
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
self.taesd: Optional[AutoencoderTiny] = None
def before_model_load(self):
    """Do nothing; this trainer has no work to perform in this hook.

    NOTE(review): presumably called by the base process before the model is
    loaded (inferred from the name only) -- confirm against the caller.
    """
@@ -56,6 +60,16 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
if self.model_config.is_xl:
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
torch_dtype=get_torch_dtype(self.train_config.dtype))
else:
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
torch_dtype=get_torch_dtype(self.train_config.dtype))
self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
self.taesd.eval()
self.taesd.requires_grad_(False)
def hook_before_train_loop(self):
# move vae to device if we did not cache latents
@@ -70,6 +84,96 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter is not None:
self.adapter.to(self.device_torch)
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
    """Take one large scheduler step per batch item and decode the result.

    For each item, the shared noise scheduler is temporarily rewritten to a
    single step that goes from the sigma of the item's own timestep to a
    randomly chosen sigma at or after it in the training sigma table. The
    stepped latent then has ``noise * end_sigma`` subtracted, the items are
    concatenated, divided by the VAE scaling factor and decoded by the VAE.

    Author's note: this currently only works on euler_a (that I know of). It
    would work on other schedulers, but needs to be coded to do so.

    Args:
        pred: Model prediction; split along dim 0 into one chunk per item.
        noisy_latents: Latents given to the model; split along dim 0 and
            detached before being passed to the scheduler step.
        timesteps: Per-item timesteps; split along dim 0. Each value must
            match exactly one entry of the scheduler's training timesteps,
            otherwise the ``.item()`` lookup raises.
        noise: Noise tensor; split along dim 0, scaled by the chosen end sigma
            and subtracted from the stepped latent.
        batch: Object exposing ``tensor``, which is returned as the target.

    Returns:
        A tuple ``(output, target)`` where ``output`` is the VAE-decoded
        sample and ``target`` is ``batch.tensor`` moved to the training
        device and dtype.
    """
    scheduler = self.sd.noise_scheduler
    # loop-invariant, so resolve the training dtype once
    dtype = get_torch_dtype(self.train_config.dtype)

    # needs to be done on each item in batch as they may all have different timesteps
    batch_size = pred.shape[0]
    pred_chunks = torch.chunk(pred, batch_size, dim=0)
    noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
    timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
    noise_chunks = torch.chunk(noise, batch_size, dim=0)

    with torch.no_grad():
        # set the timesteps to the full training schedule so we can capture
        # the timestep -> sigma table
        scheduler.set_timesteps(
            scheduler.config.num_train_timesteps,
            device=self.device_torch
        )
        train_timesteps = scheduler.timesteps.clone().detach()
        train_sigmas = scheduler.sigmas.clone().detach()

        # set the scheduler to one timestep; the step and sigmas are rebuilt
        # for each item in the batch below
        scheduler.set_timesteps(
            1,
            device=self.device_torch
        )

    denoised_pred_chunks = []

    try:
        for i in range(batch_size):
            pred_item = pred_chunks[i]
            noisy_latents_item = noisy_latents_chunks[i]
            timesteps_item = timesteps_chunks[i]
            noise_item = noise_chunks[i]

            with torch.no_grad():
                # position of this item's timestep in the training schedule;
                # .item() raises unless there is exactly one match
                timestep_idx = (train_timesteps == timesteps_item[0]).nonzero().item()
                single_step_timestep_schedule = [timesteps_item.squeeze().item()]

                # sigma we start from, and a random sigma at or after it to step to
                sigmas = train_sigmas[timestep_idx:timestep_idx + 1]
                end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
                end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]

                # build the big sigma step from the current sigma to the end sigma
                scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()
                # set our single timestep
                scheduler.timesteps = torch.from_numpy(
                    np.array(single_step_timestep_schedule, dtype=np.float32)
                ).to(device=self.device_torch)

                # set the step index to None so it will be recalculated on first step
                scheduler._step_index = None

            # the step itself stays outside no_grad so gradients reach pred
            denoised_latent = scheduler.step(
                pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
            )[0]

            residual_noise = (noise_item * end_sigma.flatten()).detach().to(
                self.device_torch, dtype=dtype
            )
            # remove the residual noise from the denoised latents.
            # Output should be a clean prediction (theoretically)
            denoised_pred_chunks.append(denoised_latent - residual_noise)
    finally:
        # always set the scheduler back to the original timesteps, even when a
        # step fails, so the shared scheduler is not left in a one-step state
        scheduler.set_timesteps(
            scheduler.config.num_train_timesteps,
            device=self.device_torch
        )

    denoised_latents = torch.cat(denoised_pred_chunks, dim=0)

    output = denoised_latents / self.sd.vae.config['scaling_factor']
    output = self.sd.vae.decode(output).sample

    if self.train_config.show_turbo_outputs:
        # preview the decoded images; no gradients are needed for display
        with torch.no_grad():
            show_tensors(output)

    # we return the decoded big-step prediction as our pred and the untouched
    # input image tensor as our target, so the loss is taken in pixel space.
    return output, batch.tensor.to(self.device_torch, dtype=dtype)
# you can expand these in a child class to make customization easier
def calculate_loss(
self,
@@ -96,6 +200,7 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
@@ -114,6 +219,7 @@ class SDTrainer(BaseSDTrainProcess):
# set masked multiplier to 1.0 so we dont double apply it
# mask_multiplier = 1.0
elif prior_pred is not None:
assert not self.train_config.train_turbo
# matching adapter prediction
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
@@ -124,9 +230,13 @@ class SDTrainer(BaseSDTrainProcess):
pred = noise_pred
if self.train_config.train_turbo:
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
ignore_snr = False
if loss_target == 'source' or loss_target == 'unaugmented':
assert not self.train_config.train_turbo
# ignore_snr = True
if batch.sigmas is None:
raise ValueError("Batch sigmas is None. This should not happen")
@@ -164,6 +274,7 @@ class SDTrainer(BaseSDTrainProcess):
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
assert not self.train_config.train_turbo
# apply a loss to the unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
@@ -178,15 +289,16 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if not self.train_config.train_turbo:
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
return loss