diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0f17ce40..fff81a75 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess): with self.timer('prepare_noise'): - self.sd.noise_scheduler.set_timesteps( - 1000, device=self.device_torch - ) + if self.train_config.noise_scheduler == 'lcm': + self.sd.noise_scheduler.set_timesteps( + 1000, device=self.device_torch, original_inference_steps=1000 + ) + else: + self.sd.noise_scheduler.set_timesteps( + 1000, device=self.device_torch + ) # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': if self.train_config.content_or_style in ['style', 'content']: @@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess): optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename) if os.path.exists(optimizer_state_file_path): # try to load + # previous param groups + # previous_params = copy.deepcopy(optimizer.param_groups) try: print(f"Loading optimizer state from {optimizer_state_file_path}") optimizer_state_dict = torch.load(optimizer_state_file_path) @@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess): print(f"Failed to load optimizer state from {optimizer_state_file_path}") print(e) + # Update the learning rates if they changed + # optimizer.param_groups = previous_params + lr_scheduler_params = self.train_config.lr_scheduler_params # make sure it had bare minimum diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index fc26d77d..cca8ced1 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -635,9 +635,9 @@ class MaskFileItemDTOMixin: # do a flip img.transpose(Image.FLIP_TOP_BOTTOM) - # randomly apply a blur up to 2% of the size of the min (width, height) + # randomly apply a blur up to 0.5% of the size of the min (width, height) min_size = min(img.width, img.height) - blur_radius = int(min_size * random.random() * 0.02) + blur_radius = int(min_size * random.random() * 0.005) img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) # make grayscale @@ -794,7 +794,6 @@ class PoiFileItemDTOMixin: except Exception as e: pass - # handle flipping if kwargs.get('flip_x', False): # flip the poi @@ -825,7 +824,28 @@ class PoiFileItemDTOMixin: poi_width = int(self.poi_width * self.dataset_config.scale) poi_height = int(self.poi_height * self.dataset_config.scale) - # determine new cropping + # expand poi to fit resolution + if poi_width < resolution: + width_difference = resolution - poi_width + poi_x = poi_x - int(width_difference / 2) + poi_width = resolution + # make sure we dont go out of bounds + if poi_x < 0: + poi_x = 0 + # if total width too much, crop + if poi_x + poi_width > initial_width: + poi_width = initial_width - poi_x + + if poi_height < resolution: + height_difference = resolution - poi_height + poi_y = poi_y - int(height_difference / 2) + poi_height = resolution + # make sure we dont go out of bounds + if poi_y < 0: + poi_y = 0 + # if total height too much, crop + if poi_y + poi_height > initial_height: + poi_height = initial_height - poi_y # crop left if poi_x > 0: diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py index 0625a030..b4dbc38d 100644 --- a/toolkit/image_utils.py +++ b/toolkit/image_utils.py @@ -5,9 +5,12 @@ import json import os import io import struct +from typing import TYPE_CHECKING import cv2 import numpy as np +import torch +from diffusers import AutoencoderTiny FILE_UNKNOWN = "Sorry, don't know how to get size for this file." @@ -424,23 +427,47 @@ def main(argv=None): is_window_shown = False -def show_img(img): +def show_img(img, name='AI Toolkit'): global is_window_shown img = np.clip(img, 0, 255).astype(np.uint8) - cv2.imshow('AI Toolkit', img[:, :, ::-1]) + cv2.imshow(name, img[:, :, ::-1]) k = cv2.waitKey(10) & 0xFF if k == 27: # Esc key to stop print('\nESC pressed, stopping') raise KeyboardInterrupt - # show again to initialize the window if first if not is_window_shown: - cv2.imshow('AI Toolkit', img[:, :, ::-1]) - k = cv2.waitKey(10) & 0xFF - if k == 27: # Esc key to stop - print('\nESC pressed, stopping') - raise KeyboardInterrupt - is_window_shown = True + is_window_shown = True + + + +def show_tensors(imgs: torch.Tensor, name='AI Toolkit'): + # if rank is 4 + if len(imgs.shape) == 4: + img_list = torch.chunk(imgs, imgs.shape[0], dim=0) + else: + img_list = [imgs] + # put images side by side + img = torch.cat(img_list, dim=3) + # img is -1 to 1, convert to 0 to 255 + img = img / 2 + 0.5 + img_numpy = img.to(torch.float32).detach().cpu().numpy() + img_numpy = np.clip(img_numpy, 0, 1) * 255 + # convert to numpy Move channel to last + img_numpy = img_numpy.transpose(0, 2, 3, 1) + # convert to uint8 + img_numpy = img_numpy.astype(np.uint8) + show_img(img_numpy[0], name=name) + + +def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'): + # decode latents + if vae.device == 'cpu': + vae.to(latents.device) + latents = latents / vae.config['scaling_factor'] + imgs = vae.decode(latents).sample + show_tensors(imgs, name=name) + def on_exit(): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 9a72d647..0b1e9a5b 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -14,6 +14,7 @@ from PIL import Image from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file, load_file from torch.nn import Parameter +from torch.utils.checkpoint import checkpoint from tqdm import tqdm from torchvision.transforms import Resize, transforms @@ -828,6 +829,8 @@ class StableDiffusion: start_timesteps=0, guidance_scale=1, add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, **kwargs, ): @@ -842,6 +845,10 @@ class StableDiffusion: ) latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) + # return latents_steps return latents