mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-06 11:10:10 +00:00
Initial commit
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import Union, Literal, List
|
||||
from diffusers import T2IAdapter
|
||||
from typing import Union, Literal, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from diffusers import T2IAdapter, AutoencoderTiny
|
||||
|
||||
import torch.functional as F
|
||||
from toolkit import train_tools
|
||||
@@ -39,6 +42,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.do_prior_prediction = False
|
||||
self.do_long_prompts = False
|
||||
self.do_guided_loss = False
|
||||
self.taesd: Optional[AutoencoderTiny] = None
|
||||
|
||||
def before_model_load(self):
    """Intentionally do nothing.

    NOTE(review): by its name this is a hook for the stage before the model
    is loaded; the call site is not visible here - confirm.
    """
    return None
|
||||
@@ -56,6 +60,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.assistant_adapter.eval()
|
||||
self.assistant_adapter.requires_grad_(False)
|
||||
flush()
|
||||
if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
|
||||
if self.model_config.is_xl:
|
||||
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
|
||||
torch_dtype=get_torch_dtype(self.train_config.dtype))
|
||||
else:
|
||||
self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
|
||||
torch_dtype=get_torch_dtype(self.train_config.dtype))
|
||||
self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
|
||||
self.taesd.eval()
|
||||
self.taesd.requires_grad_(False)
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
# move vae to device if we did not cache latents
|
||||
@@ -70,6 +84,96 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter is not None:
|
||||
self.adapter.to(self.device_torch)
|
||||
|
||||
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
    """Turn a model prediction into decoded images for the turbo loss.

    Every batch item is advanced with one large scheduler step: from the
    sigma of its own timestep to a sigma picked at random between that
    position and the end of the full training schedule. The noise scaled by
    that end sigma is then subtracted and the result is decoded with
    ``self.sd.vae``.

    The scheduler's ``timesteps``, ``sigmas`` and ``_step_index`` are
    overwritten for each item, so this only works with schedulers exposing
    those attributes (written against euler_a). The scheduler is always put
    back on the full training schedule before returning, including when an
    error is raised part-way through the batch.

    Args:
        pred: model output for the batch; split along dim 0.
        noisy_latents: the noised latents the prediction was made from.
        timesteps: one training timestep per batch item; each value has to
            occur exactly once in the scheduler's full training schedule.
        noise: the noise tensor paired with ``noisy_latents``.
        batch: object providing ``tensor``, which is returned as the target.

    Returns:
        ``(pred, target)``: the VAE-decoded denoised latents, and
        ``batch.tensor`` moved to the training device and dtype. The
        scheduler step and the decode run with autograd enabled.
    """
    scheduler = self.sd.noise_scheduler
    device = self.device_torch
    dtype = get_torch_dtype(self.train_config.dtype)
    num_train_timesteps = scheduler.config.num_train_timesteps

    # needs to be done on each item in batch as they may all have different timesteps
    batch_size = pred.shape[0]
    pred_chunks = torch.chunk(pred, batch_size, dim=0)
    noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
    timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
    noise_chunks = torch.chunk(noise, batch_size, dim=0)

    with torch.no_grad():
        # expand to the full training schedule so the sigma of any training
        # timestep can be looked up by index
        scheduler.set_timesteps(num_train_timesteps, device=device)
        train_timesteps = scheduler.timesteps.clone().detach()
        train_sigmas = scheduler.sigmas.clone().detach()

    denoised_pred_chunks = []
    try:
        with torch.no_grad():
            # set the scheduler to one timestep; the actual step and sigmas are
            # rebuilt for each item in the batch below
            scheduler.set_timesteps(1, device=device)

        item_chunks = zip(pred_chunks, noisy_latents_chunks, timesteps_chunks, noise_chunks)
        for pred_item, noisy_latents_item, timesteps_item, noise_item in item_chunks:
            with torch.no_grad():
                # position of this item's timestep inside the training schedule
                timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
                single_step_timestep_schedule = [timesteps_item.squeeze().item()]
                # sigma we are stepping from
                sigmas = train_sigmas[timestep_idx:timestep_idx + 1]

                # sigma we are stepping to: anywhere from the current position
                # to the last entry of the schedule
                end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
                end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1]

                # build the single big sigma step and its one-entry timestep schedule
                scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()
                scheduler.timesteps = torch.from_numpy(
                    np.array(single_step_timestep_schedule, dtype=np.float32)
                ).to(device=device)

                # set the step index to None so it will be recalculated on first step
                scheduler._step_index = None

            # outside no_grad: the loss has to backpropagate through this step
            denoised_latent = scheduler.step(
                pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
            )[0]

            residual_noise = (noise_item * end_sigma.flatten()).detach().to(device, dtype=dtype)
            # remove the residual noise from the denoised latents. Output should be
            # a clean prediction (theoretically)
            denoised_latent = denoised_latent - residual_noise

            denoised_pred_chunks.append(denoised_latent)
    finally:
        # set the scheduler back to the original timesteps, even if an item failed,
        # so later users of the shared scheduler do not see the one-step state
        scheduler.set_timesteps(num_train_timesteps, device=device)

    denoised_latents = torch.cat(denoised_pred_chunks, dim=0)

    output = denoised_latents / self.sd.vae.config['scaling_factor']
    output = self.sd.vae.decode(output).sample

    if self.train_config.show_turbo_outputs:
        # the output is already decoded, so it can be previewed directly
        with torch.no_grad():
            show_tensors(output)

    # pred is the decoded big-step result; target is batch.tensor
    # (presumably the input images - confirm against the data loader)
    return output, batch.tensor.to(device, dtype=dtype)
|
||||
|
||||
# you can expand these in a child class to make customization easier
|
||||
def calculate_loss(
|
||||
self,
|
||||
@@ -96,6 +200,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
||||
assert not self.train_config.train_turbo
|
||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||
stretched_mask_multiplier = value_map(
|
||||
mask_multiplier,
|
||||
@@ -114,6 +219,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# set masked multiplier to 1.0 so we dont double apply it
|
||||
# mask_multiplier = 1.0
|
||||
elif prior_pred is not None:
|
||||
assert not self.train_config.train_turbo
|
||||
# matching adapter prediction
|
||||
target = prior_pred
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
@@ -124,9 +230,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
pred = noise_pred
|
||||
|
||||
if self.train_config.train_turbo:
|
||||
pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)
|
||||
|
||||
ignore_snr = False
|
||||
|
||||
if loss_target == 'source' or loss_target == 'unaugmented':
|
||||
assert not self.train_config.train_turbo
|
||||
# ignore_snr = True
|
||||
if batch.sigmas is None:
|
||||
raise ValueError("Batch sigmas is None. This should not happen")
|
||||
@@ -164,6 +274,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
prior_loss = None
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
||||
assert not self.train_config.train_turbo
|
||||
# to a loss to unmasked areas of the prior for unmasked regularization
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
prior_pred.float(),
|
||||
@@ -178,15 +289,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
if not self.train_config.train_turbo:
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
@@ -198,6 +198,8 @@ class TrainConfig:
|
||||
self.train_unet = kwargs.get('train_unet', True)
|
||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||
self.train_refiner = kwargs.get('train_refiner', True)
|
||||
self.train_turbo = kwargs.get('train_turbo', False)
|
||||
self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
|
||||
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||
self.snr_gamma = kwargs.get('snr_gamma', None)
|
||||
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
|
||||
@@ -263,6 +265,9 @@ class TrainConfig:
|
||||
self.standardize_images = kwargs.get('standardize_images', False)
|
||||
self.standardize_latents = kwargs.get('standardize_latents', False)
|
||||
|
||||
if self.train_turbo and not self.noise_scheduler.startswith("euler"):
|
||||
raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import io
|
||||
import struct
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
@@ -425,43 +426,63 @@ def main(argv=None):
|
||||
|
||||
|
||||
# Module-level state shared by update_image() and the viewer thread.
display_lock = threading.Lock()    # held while current_img is read or replaced
update_event = threading.Event()   # set when a new frame is waiting
current_img = None                 # pending (image, window name) pair, or None
is_window_shown = False            # flipped to True once the viewer is started
|
||||
|
||||
def update_image(img, name):
    """Publish a frame for the viewer.

    Stores ``(img, name)`` in the module-level ``current_img`` slot while
    holding ``display_lock`` and sets ``update_event``. A frame that has not
    been consumed yet is overwritten.
    """
    global current_img
    pending = (img, name)
    with display_lock:
        current_img = pending
        update_event.set()
|
||||
|
||||
def display_image_in_thread():
    """Start the daemon viewer thread unless ``is_window_shown`` is already set.

    The thread blocks on ``update_event``, draws the pending frame with
    ``cv2.imshow`` and then polls the keyboard once. Because it waits for the
    next frame before polling again, ESC is only noticed right after a frame
    has been delivered. On ESC all windows are destroyed and the thread ends;
    ``is_window_shown`` is not reset here.
    """
    global is_window_shown

    def display_img():
        global current_img
        running = True
        while running:
            update_event.wait()
            with display_lock:
                if current_img:
                    frame, title = current_img
                    cv2.imshow(title, frame)
                    current_img = None
                update_event.clear()
            pressed = cv2.waitKey(1) & 0xFF
            if pressed == 27:  # Esc key to stop
                cv2.destroyAllWindows()
                print('\nESC pressed, stopping')
                running = False

    if is_window_shown:
        return
    is_window_shown = True
    viewer = threading.Thread(target=display_img, daemon=True)
    viewer.start()
|
||||
|
||||
|
||||
def show_img(img, name='AI Toolkit'):
    """Queue an image for the preview window, starting the viewer if needed.

    The image is clipped to 0..255, converted to uint8 and handed to
    ``update_image`` with its last axis reversed. Drawing and ESC handling
    are left to the viewer thread started by ``display_image_in_thread``.
    This function no longer calls ``cv2.imshow`` itself, and it no longer
    sets ``is_window_shown``: setting it here before calling
    ``display_image_in_thread`` kept that function from ever starting its
    thread.

    Args:
        img: array indexed as (height, width, channel).
            NOTE(review): presumably RGB, flipped to OpenCV's BGR - confirm.
        name: title of the window the image is shown in.
    """
    img = np.clip(img, 0, 255).astype(np.uint8)
    update_image(img[:, :, ::-1], name)
    if not is_window_shown:
        # display_image_in_thread() sets the flag itself when it starts the viewer
        display_image_in_thread()
|
||||
|
||||
|
||||
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    """Show a batch of image tensors side by side in the preview window.

    Args:
        imgs: either a 4-axis batch or a single 3-axis image. Values are
            treated as lying in -1..1; anything outside is clipped.
            NOTE(review): layout assumed to be (N, C, H, W) / (C, H, W),
            which is what the channel-last transpose below implies.
        name: window title passed on to ``show_img``.
    """
    if imgs.dim() == 4:
        img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
    else:
        # a single image has no batch axis; add one, because the concat on
        # dim 3 and the 4-axis transpose below both require it
        img_list = [imgs.unsqueeze(0)]

    # put images side by side (concatenate along the last axis)
    img = torch.cat(img_list, dim=3)
    # img is -1 to 1, convert to 0 to 1 here and to 0 to 255 after clipping
    img = img / 2 + 0.5
    img_numpy = img.to(torch.float32).detach().cpu().numpy()
    img_numpy = np.clip(img_numpy, 0, 1) * 255
    # move channel to last
    img_numpy = img_numpy.transpose(0, 2, 3, 1)
    img_numpy = img_numpy.astype(np.uint8)

    # after the concat there is exactly one (wide) image left in the batch axis
    show_img(img_numpy[0], name=name)
|
||||
|
||||
|
||||
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||
# decode latents
|
||||
if vae.device == 'cpu':
|
||||
vae.to(latents.device)
|
||||
latents = latents / vae.config['scaling_factor']
|
||||
@@ -469,7 +490,6 @@ def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit
|
||||
show_tensors(imgs, name=name)
|
||||
|
||||
|
||||
|
||||
def on_exit():
    """Destroy all OpenCV windows, but only when ``is_window_shown`` is set."""
    if not is_window_shown:
        return
    cv2.destroyAllWindows()
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from diffusers.utils import is_torch_xla_available
|
||||
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
@@ -43,13 +43,14 @@ class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.sampler = None
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
model = ModelWrapper(unet, scheduler.alphas_cumprod)
|
||||
if scheduler.config.prediction_type == "v_prediction":
|
||||
self.k_diffusion_model = CompVisVDenoiser(model)
|
||||
else:
|
||||
self.k_diffusion_model = CompVisDenoiser(model)
|
||||
raise NotImplementedError("This pipeline is not implemented yet")
|
||||
# self.sampler = None
|
||||
# scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
# model = ModelWrapper(unet, scheduler.alphas_cumprod)
|
||||
# if scheduler.config.prediction_type == "v_prediction":
|
||||
# self.k_diffusion_model = CompVisVDenoiser(model)
|
||||
# else:
|
||||
# self.k_diffusion_model = CompVisDenoiser(model)
|
||||
|
||||
def set_scheduler(self, scheduler_type: str):
|
||||
library = importlib.import_module("k_diffusion")
|
||||
|
||||
@@ -755,7 +755,7 @@ class StableDiffusion:
|
||||
add_time_ids=None,
|
||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
is_input_scaled=True,
|
||||
is_input_scaled=False,
|
||||
**kwargs,
|
||||
):
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user