New img2img image generation. Various tweaks and fixes

This commit is contained in:
Jaret Burkett
2024-07-24 04:13:41 -06:00
parent 8d799031cf
commit 80aa2dbb80
9 changed files with 285 additions and 47 deletions

View File

@@ -0,0 +1,188 @@
import os
import random
from collections import OrderedDict
from typing import List
import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLImg2ImgPipeline
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.train_tools import get_torch_dtype
from controlnet_aux.midas import MidasDetector
from diffusers.utils import load_image
def flush():
    torch.cuda.empty_cache()
    gc.collect()
class GenerateConfig:

    def __init__(self, **kwargs):
        self.prompts: List[str]
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.walk_seed = kwargs.get('walk_seed', False)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.ext = kwargs.get('ext', 'png')
        self.denoise_strength = kwargs.get('denoise_strength', 0.5)
        self.trigger_word = kwargs.get('trigger_word', None)
class Img2ImgGenerator(BaseExtensionProcess):

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
        self.device = self.get_conf('device', 'cuda')
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.is_latents_cached = True
        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.params = []
        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                if not is_caching:
                    self.is_latents_cached = False
                if dataset.is_reg:
                    if self.datasets_reg is None:
                        self.datasets_reg = []
                    self.datasets_reg.append(dataset)
                else:
                    if self.datasets is None:
                        self.datasets = []
                    self.datasets.append(dataset)

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )
        print(f"Using device {self.device}")
        self.data_loader: DataLoader = None
        self.adapter: T2IAdapter = None

    def to_pil(self, img):
        # image comes in -1 to 1. convert to a PIL RGB image
        img = (img + 1) / 2
        img = img.clamp(0, 1)
        img = img[0].permute(1, 2, 0).cpu().numpy()
        img = (img * 255).astype(np.uint8)
        image = Image.fromarray(img)
        return image
    def run(self):
        with torch.no_grad():
            super().run()
            print("Loading model...")
            self.sd.load_model()
            device = torch.device(self.device)

            if self.model_config.is_xl:
                pipe = StableDiffusionXLImg2ImgPipeline(
                    vae=self.sd.vae,
                    unet=self.sd.unet,
                    text_encoder=self.sd.text_encoder[0],
                    text_encoder_2=self.sd.text_encoder[1],
                    tokenizer=self.sd.tokenizer[0],
                    tokenizer_2=self.sd.tokenizer[1],
                    scheduler=get_sampler(self.generate_config.sampler),
                ).to(device, dtype=self.torch_dtype)
            else:
                raise NotImplementedError("Only XL models are supported")
            pipe.set_progress_bar_config(disable=True)

            # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)

            self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
            num_batches = len(self.data_loader)
            pbar = tqdm(total=num_batches, desc="Generating images")
            seed = self.generate_config.seed
            # load images from datasets, use tqdm
            for i, batch in enumerate(self.data_loader):
                batch: DataLoaderBatchDTO = batch
                file_item: FileItemDTO = batch.file_items[0]
                img_path = file_item.path
                img_filename = os.path.basename(img_path)
                img_filename_no_ext = os.path.splitext(img_filename)[0]
                img_filename = img_filename_no_ext + '.' + self.generate_config.ext
                output_path = os.path.join(self.output_folder, img_filename)
                output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
                if self.copy_inputs_to is not None:
                    output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
                    output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
                else:
                    output_inputs_path = None
                    output_inputs_caption_path = None

                caption = batch.get_caption_list()[0]
                if self.generate_config.trigger_word is not None:
                    caption = caption.replace('[trigger]', self.generate_config.trigger_word)

                img: torch.Tensor = batch.tensor.clone()
                image = self.to_pil(img)
                # image.save(output_depth_path)

                pipe: StableDiffusionXLImg2ImgPipeline = pipe
                gen_images = pipe.__call__(
                    prompt=caption,
                    negative_prompt=self.generate_config.neg,
                    image=image,
                    num_inference_steps=self.generate_config.sample_steps,
                    guidance_scale=self.generate_config.guidance_scale,
                    strength=self.generate_config.denoise_strength,
                ).images[0]

                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                gen_images.save(output_path)

                # save caption
                with open(output_caption_path, 'w') as f:
                    f.write(caption)

                if output_inputs_path is not None:
                    os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
                    image.save(output_inputs_path)
                    with open(output_inputs_caption_path, 'w') as f:
                        f.write(caption)

                pbar.update(1)
                batch.cleanup()

            pbar.close()
            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()
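
Read together, the get_conf calls above imply a process config shaped roughly like the following. This is a sketch: the keys come straight from the code, but the values are placeholders and the exact schema of the nested model/generate/datasets blocks is an assumption.

config = OrderedDict({
    'output_folder': '/path/to/output',      # required
    'copy_inputs_to': None,                  # optional: also save the input image and caption
    'device': 'cuda',
    'dtype': 'float16',
    'model': {},                             # kwargs for ModelConfig; must describe an SDXL model
    'generate': {                            # kwargs for GenerateConfig
        'sampler': 'ddpm',
        'neg': '',
        'sample_steps': 20,
        'guidance_scale': 7,
        'denoise_strength': 0.5,             # img2img strength
        'ext': 'png',
        'trigger_word': None,                # substituted for '[trigger]' in captions
    },
    'datasets': [],                          # list of DatasetConfig kwargs dicts
})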

View File

@@ -36,7 +36,24 @@ class PureLoraGenerator(Extension):
        return PureLoraGenerator


# Batch img2img generation over an existing dataset
class Img2ImgGeneratorExtension(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "batch_img2img"

    # name is the name of the extension, used for printing
    name = "Img2ImgGeneratorExtension"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed, and return it
        from .Img2ImgGenerator import Img2ImgGenerator
        return Img2ImgGenerator
AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
]
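
For context, resolving a process class by uid from this list might look like the sketch below; the toolkit's actual discovery code is not part of this diff, so the function name is illustrative.

def get_extension_process(uid: str):
    for ext in AI_TOOLKIT_EXTENSIONS:
        if ext.uid == uid:
            return ext.get_process()  # triggers the lazy import inside get_process
    raise ValueError(f"No extension with uid '{uid}'")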

View File

@@ -483,6 +483,7 @@ class SDTrainer(BaseSDTrainProcess):
                noise=noise,
                sd=self.sd,
                unconditional_embeds=unconditional_embeds,
                scaler=self.scaler,
                **kwargs
            )

View File

@@ -739,6 +739,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
            # add to noise
            noise += noise_shift

        # standardize the noise
        std = noise.std(dim=(2, 3), keepdim=True)
        normalizer = 1 / (std + 1e-6)
        noise = noise * normalizer

        return noise

    def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
@@ -975,14 +980,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
                noise = noise * noise_multiplier

                latent_multiplier = self.train_config.latent_multiplier
                # handle adaptive scaling based on std
                if self.train_config.adaptive_scaling_factor:
                    std = latents.std(dim=(2, 3), keepdim=True)
                    normalizer = 1 / (std + 1e-6)
                    latent_multiplier = normalizer

                latents = latents * latent_multiplier
                batch.latents = latents

                # normalize latents to a mean of 0 and an std of 1
                # mean_zero_latents = latents - latents.mean()
                # latents = mean_zero_latents / mean_zero_latents.std()
                if batch.unconditional_latents is not None:
                    batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
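
Both hunks above use the same normalization trick: rescale a tensor so its standard deviation over the spatial dimensions is 1 for every sample and channel. A minimal standalone check of that math:

import torch

noise = torch.randn(2, 4, 64, 64) * 3.0    # deliberately mis-scaled input
std = noise.std(dim=(2, 3), keepdim=True)  # std over H and W only
noise = noise * (1 / (std + 1e-6))         # 1e-6 guards against division by zero
print(noise.std(dim=(2, 3)).flatten())     # every entry is ~1.0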

View File

@@ -80,6 +80,7 @@ class GenerateProcess(BaseProcess):
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.device = self.get_conf('device', self.job.device)
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
        self.progress_bar = None
        self.sd = StableDiffusion(
@@ -87,49 +88,57 @@ class GenerateProcess(BaseProcess):
            model_config=self.model_config,
            dtype=self.model_config.dtype,
        )
        print(f"Using device {self.device}")

    def clean_prompt(self, prompt: str):
        # strip everything except alphanumerics, spaces, commas, and quote characters
        return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")

    def run(self):
        with torch.no_grad():
            super().run()
            print("Loading model...")
            self.sd.load_model()
            self.sd.pipeline.to(self.device, self.torch_dtype)

            print("Compiling model...")
            # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
            if self.generate_config.compile:
                self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")

            print(f"Generating {len(self.generate_config.prompts)} images")
            # build prompt image configs
            prompt_image_configs = []
            for prompt in self.generate_config.prompts:
                width = self.generate_config.width
                height = self.generate_config.height
                prompt = self.clean_prompt(prompt)
                if self.generate_config.size_list is not None:
                    # randomly select a size
                    width, height = random.choice(self.generate_config.size_list)
                prompt_image_configs.append(GenerateImageConfig(
                    prompt=prompt,
                    prompt_2=self.generate_config.prompt_2,
                    width=width,
                    height=height,
                    num_inference_steps=self.generate_config.sample_steps,
                    guidance_scale=self.generate_config.guidance_scale,
                    negative_prompt=self.generate_config.neg,
                    negative_prompt_2=self.generate_config.neg_2,
                    seed=self.generate_config.seed,
                    guidance_rescale=self.generate_config.guidance_rescale,
                    output_ext=self.generate_config.ext,
                    output_folder=self.output_folder,
                    add_prompt_file=self.generate_config.prompt_file
                ))

            # generate images
            self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)

            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()
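
A quick check of what clean_prompt keeps (alphanumerics, spaces, commas, and quote characters), using the same filter expression as above:

def clean_prompt(prompt: str) -> str:
    # same expression as GenerateProcess.clean_prompt
    return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")

print(clean_prompt("a <photo> of a cat, 8k!"))  # -> a photo of a cat, 8k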

View File

@@ -266,6 +266,8 @@ class TrainConfig:
        self.reg_weight = kwargs.get('reg_weight', 1.0)
        self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
        self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
        # automatically adapt the vae scaling factor based on the latent std
        self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)

        # dropout that happens before encoding. It functions independently per text encoder
        self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)

View File

@@ -387,7 +387,7 @@ class CaptionProcessingDTOMixin:
            # join back together
            caption = ', '.join(token_list)
            # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

        if self.dataset_config.random_triggers:
            num_triggers = self.dataset_config.random_triggers_max

View File

@@ -407,6 +407,7 @@ def get_guided_loss_polarity(
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        scaler=None,
        **kwargs
):
    dtype = get_torch_dtype(sd.torch_dtype)
@@ -473,7 +474,10 @@ def get_guided_loss_polarity(
    loss = loss.mean([1, 2, 3])
    loss = loss.mean()

    if scaler is not None:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
@@ -590,6 +594,7 @@ def get_guidance_loss(
        unconditional_embeds: Optional[PromptEmbeds] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        **kwargs
):
    # TODO add others and process individual batch items separately
@@ -621,6 +626,7 @@ def get_guidance_loss(
            batch,
            noise,
            sd,
            scaler=scaler,
            **kwargs
        )
    elif guidance_type == "tnt":
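
The new scaler parameter lets the manual backward call inside get_guided_loss_polarity participate in mixed-precision training. For context, this is the standard torch.cuda.amp pattern it mirrors (the model and optimizer here are illustrative):

import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
scaler.step(optimizer)         # unscales gradients, then steps the optimizer
scaler.update()                # adjusts the scale factor for the next iteration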

View File

@@ -41,9 +41,12 @@ sd_config = {
    "prediction_type": "epsilon",
    "sample_max_value": 1.0,
    "set_alpha_to_one": False,
    # "skip_prk_steps": False,  # for training
    "skip_prk_steps": True,
    # "steps_offset": 1,
    "steps_offset": 0,
    # "timestep_spacing": "trailing",  # for training
    "timestep_spacing": "leading",
    "trained_betas": None
}
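
For context on the spacing and offset swap: with diffusers schedulers, "leading" spacing never reaches the final training timestep at inference, while "trailing" starts from it, which is why the trailing variant is marked "for training" above. A quick illustration using diffusers' DDIMScheduler (assuming this config block feeds a diffusers scheduler):

from diffusers import DDIMScheduler

for spacing in ("leading", "trailing"):
    sched = DDIMScheduler(num_train_timesteps=1000, timestep_spacing=spacing, steps_offset=0)
    sched.set_timesteps(5)
    print(spacing, sched.timesteps.tolist())

# leading:  [800, 600, 400, 200, 0]   -- never samples t=999
# trailing: [999, 799, 599, 399, 199] -- includes the final timestep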