Initial commit

This commit is contained in:
Jaret Burkett
2023-12-29 13:07:35 -07:00
parent 0892dec4a5
commit bafacf3b65
5 changed files with 175 additions and 37 deletions

View File

@@ -1,6 +1,9 @@
import random
from collections import OrderedDict
from typing import Union, Literal, List
from diffusers import T2IAdapter
from typing import Union, Literal, List, Optional
import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
import torch.functional as F
from toolkit import train_tools
@@ -39,6 +42,7 @@ class SDTrainer(BaseSDTrainProcess):
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
self.taesd: Optional[AutoencoderTiny] = None
def before_model_load(self):
pass
@@ -56,6 +60,16 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
if self.model_config.is_xl:
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
torch_dtype=get_torch_dtype(self.train_config.dtype))
else:
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
torch_dtype=get_torch_dtype(self.train_config.dtype))
self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
self.taesd.eval()
self.taesd.requires_grad_(False)
def hook_before_train_loop(self):
# move vae to device if we did not cache latents
@@ -70,6 +84,96 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter is not None:
self.adapter.to(self.device_torch)
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
    """Convert a single model prediction into a (pred, target) pair for "turbo" training.

    For each batch item we take one large scheduler step from the item's current
    timestep toward the end of the sigma schedule, subtract the residual noise that
    remains at the randomly chosen end sigma, decode the result through the VAE,
    and target the original input images in pixel space.

    Args:
        pred: raw model output (noise / v prediction), shape (B, C, H, W).
        noisy_latents: the noised latents the prediction was made from.
        timesteps: per-item training timesteps, length B.
        noise: the noise that was added to produce ``noisy_latents``.
        batch: training batch; ``batch.latents`` (clean latents) and
            ``batch.tensor`` (input images) are read here.

    Returns:
        Tuple of (decoded pixel-space prediction, input image tensor) — both on
        ``self.device_torch`` in the training dtype — for the caller's loss.

    NOTE(review): the author states this currently only works with euler_a-style
    schedulers; the sigma/timestep surgery below assumes that scheduler layout.
    """
    # to process turbo learning, we make one big step from our current timestep to the end
    # we then denoise the prediction on that remaining step and target our loss to our target latents
    # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so.

    # needs to be done on each item in batch as they may all have different timesteps
    batch_size = pred.shape[0]
    pred_chunks = torch.chunk(pred, batch_size, dim=0)
    noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
    timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
    latent_chunks = torch.chunk(batch.latents, batch_size, dim=0)
    noise_chunks = torch.chunk(noise, batch_size, dim=0)

    with torch.no_grad():
        # set the timesteps to 1000 so we can capture them to calculate the sigmas
        self.sd.noise_scheduler.set_timesteps(
            self.sd.noise_scheduler.config.num_train_timesteps,
            device=self.device_torch
        )
        # snapshot the full training schedule before we overwrite it below
        train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach()
        train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach()

        # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step
        self.sd.noise_scheduler.set_timesteps(
            1,
            device=self.device_torch
        )

    denoised_pred_chunks = []
    # NOTE(review): target_pred_chunks is never populated or read — looks like a
    # leftover; confirm before removing.
    target_pred_chunks = []

    for i in range(batch_size):
        pred_item = pred_chunks[i]
        noisy_latents_item = noisy_latents_chunks[i]
        timesteps_item = timesteps_chunks[i]
        latents_item = latent_chunks[i]
        noise_item = noise_chunks[i]
        with torch.no_grad():
            # index of this item's timestep within the full training schedule
            timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
            single_step_timestep_schedule = [timesteps_item.squeeze().item()]
            # extract the sigma idx for our midpoint timestep
            sigmas = train_sigmas[timestep_idx:timestep_idx + 1]
            # pick a random end point at or after the current sigma; the step
            # below then denoises across that whole span at once
            end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
            end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]

            # add noise to our target

            # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step
            # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach()
            self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()

            # set our single timestep
            self.sd.noise_scheduler.timesteps = torch.from_numpy(
                np.array(single_step_timestep_schedule, dtype=np.float32)
            ).to(device=self.device_torch)
            # set the step index to None so it will be recalculated on first step
            self.sd.noise_scheduler._step_index = None

        # scheduler.step(model_output, timestep, sample) — run outside no_grad so
        # gradients flow from the loss back through pred_item
        denoised_latent = self.sd.noise_scheduler.step(
            pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
        )[0]

        residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
        # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
        denoised_latent = denoised_latent - residual_noise

        denoised_pred_chunks.append(denoised_latent)

    denoised_latents = torch.cat(denoised_pred_chunks, dim=0)

    # set the scheduler back to the original timesteps
    self.sd.noise_scheduler.set_timesteps(
        self.sd.noise_scheduler.config.num_train_timesteps,
        device=self.device_torch
    )

    # decode to pixel space; the loss below is computed against the input images
    output = denoised_latents / self.sd.vae.config['scaling_factor']
    output = self.sd.vae.decode(output).sample

    if self.train_config.show_turbo_outputs:
        # since we are completely denoising, we can show them here
        with torch.no_grad():
            show_tensors(output)

    # we return our big partial step denoised latents as our pred and our untouched latents as our target.
    # you can do mse against the two here or run the denoised through the vae for pixel space loss against the
    # input tensor images.
    return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
# you can expand these in a child class to make customization easier
def calculate_loss(
self,
@@ -96,6 +200,7 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
@@ -114,6 +219,7 @@ class SDTrainer(BaseSDTrainProcess):
# set masked multiplier to 1.0 so we dont double apply it
# mask_multiplier = 1.0
elif prior_pred is not None:
assert not self.train_config.train_turbo
# matching adapter prediction
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
@@ -124,9 +230,13 @@ class SDTrainer(BaseSDTrainProcess):
pred = noise_pred
if self.train_config.train_turbo:
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
ignore_snr = False
if loss_target == 'source' or loss_target == 'unaugmented':
assert not self.train_config.train_turbo
# ignore_snr = True
if batch.sigmas is None:
raise ValueError("Batch sigmas is None. This should not happen")
@@ -164,6 +274,7 @@ class SDTrainer(BaseSDTrainProcess):
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
assert not self.train_config.train_turbo
# to a loss to unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
@@ -178,15 +289,16 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if not self.train_config.train_turbo:
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
return loss

View File

@@ -198,6 +198,8 @@ class TrainConfig:
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.train_refiner = kwargs.get('train_refiner', True)
self.train_turbo = kwargs.get('train_turbo', False)
self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.snr_gamma = kwargs.get('snr_gamma', None)
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
@@ -263,6 +265,9 @@ class TrainConfig:
self.standardize_images = kwargs.get('standardize_images', False)
self.standardize_latents = kwargs.get('standardize_latents', False)
if self.train_turbo and not self.noise_scheduler.startswith("euler"):
raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
class ModelConfig:
def __init__(self, **kwargs):

View File

@@ -5,6 +5,7 @@ import json
import os
import io
import struct
import threading
from typing import TYPE_CHECKING
import cv2
@@ -425,43 +426,63 @@ def main(argv=None):
is_window_shown = False
display_lock = threading.Lock()
current_img = None
update_event = threading.Event()
def update_image(img, name):
    """Hand the latest frame to the display thread.

    Stores ``(img, name)`` in the shared ``current_img`` slot under
    ``display_lock`` and signals ``update_event`` so the viewer thread wakes up.
    Only the newest frame is kept; an unconsumed previous frame is overwritten.
    """
    global current_img
    with display_lock:
        current_img = (img, name)
    # wake the display thread (it clears the event after consuming the frame)
    update_event.set()
def display_image_in_thread():
    """Start (once) a daemon thread that shows queued frames via OpenCV.

    The worker waits on ``update_event``, pops the pending frame from the shared
    ``current_img`` slot under ``display_lock``, and displays it with
    ``cv2.imshow``. Pressing Esc closes all windows and ends the worker loop.
    Guarded by the module-level ``is_window_shown`` flag so only one viewer
    thread is ever spawned.
    """
    global is_window_shown

    def display_img():
        global current_img
        while True:
            # block until update_image() publishes a new frame
            update_event.wait()
            with display_lock:
                if current_img:
                    img, name = current_img
                    cv2.imshow(name, img)
                    current_img = None
                    update_event.clear()
            # waitKey also pumps the OpenCV GUI event loop; must run on this thread
            if cv2.waitKey(1) & 0xFF == 27:  # Esc key to stop
                cv2.destroyAllWindows()
                print('\nESC pressed, stopping')
                break

    # NOTE(review): indentation was lost in this rendering — assuming the
    # start-once guard lives here, around the thread launch; confirm.
    if not is_window_shown:
        is_window_shown = True
        threading.Thread(target=display_img, daemon=True).start()
def show_img(img, name='AI Toolkit'):
    """Queue an RGB image (HWC, 0-255 range) for display in the viewer thread.

    Clips and converts the array to uint8, flips RGB->BGR for OpenCV, and hands
    the frame to the background display thread via ``update_image``. The first
    call lazily starts the viewer thread through ``display_image_in_thread``.

    Defect fixed: this span contained BOTH the old synchronous implementation
    (direct ``cv2.imshow`` + ``cv2.waitKey`` + ``raise KeyboardInterrupt`` on
    Esc) and the new threaded hand-off merged together, so the lines after the
    ``raise`` were unreachable and the frame was displayed twice. Only the
    threaded path is kept; ``display_image_in_thread`` owns the window loop and
    the ``is_window_shown`` start-once flag.
    """
    global is_window_shown
    img = np.clip(img, 0, 255).astype(np.uint8)
    # OpenCV expects BGR channel order; reverse the channel axis before queueing
    update_image(img[:, :, ::-1], name)
    if not is_window_shown:
        # display_image_in_thread() sets is_window_shown and spawns the
        # daemon viewer thread on first use
        display_image_in_thread()
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    """Display an image tensor in the preview window.

    Accepts either a single CHW-style image tensor or a batched NCHW tensor in
    the [-1, 1] range. Batched inputs are concatenated side by side along the
    width axis, rescaled to [0, 255] uint8 HWC, and passed to ``show_img``.
    """
    # split a batched tensor into per-image views; wrap a single image in a list
    if imgs.dim() == 4:
        panels = list(torch.chunk(imgs, imgs.shape[0], dim=0))
    else:
        panels = [imgs]

    # lay the images out side by side along the width dimension
    strip = torch.cat(panels, dim=3)

    # map [-1, 1] -> [0, 1], then to a clipped 0-255 numpy array
    strip = strip / 2 + 0.5
    arr = np.clip(strip.to(torch.float32).detach().cpu().numpy(), 0, 1) * 255

    # NCHW -> NHWC uint8 for OpenCV display
    arr = arr.transpose(0, 2, 3, 1).astype(np.uint8)
    show_img(arr[0], name=name)
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
# decode latents
if vae.device == 'cpu':
vae.to(latents.device)
latents = latents / vae.config['scaling_factor']
@@ -469,7 +490,6 @@ def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit
show_tensors(imgs, name=name)
def on_exit():
    """Process-exit hook: close any OpenCV windows the viewer thread opened."""
    if is_window_shown:
        cv2.destroyAllWindows()

View File

@@ -5,8 +5,8 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
@@ -43,13 +43,14 @@ class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
unet=unet,
scheduler=scheduler,
)
self.sampler = None
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
raise NotImplementedError("This pipeline is not implemented yet")
# self.sampler = None
# scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
# model = ModelWrapper(unet, scheduler.alphas_cumprod)
# if scheduler.config.prediction_type == "v_prediction":
# self.k_diffusion_model = CompVisVDenoiser(model)
# else:
# self.k_diffusion_model = CompVisDenoiser(model)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")

View File

@@ -755,7 +755,7 @@ class StableDiffusion:
add_time_ids=None,
conditional_embeddings: Union[PromptEmbeds, None] = None,
unconditional_embeddings: Union[PromptEmbeds, None] = None,
is_input_scaled=True,
is_input_scaled=False,
**kwargs,
):
with torch.no_grad():