Added some features for an LCM condenser plugin

This commit is contained in:
Jaret Burkett
2023-11-15 08:56:45 -07:00
parent 4f9cdd916a
commit e47006ed70
4 changed files with 80 additions and 16 deletions

View File

@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
with self.timer('prepare_noise'): with self.timer('prepare_noise'):
self.sd.noise_scheduler.set_timesteps( if self.train_config.noise_scheduler == 'lcm':
1000, device=self.device_torch self.sd.noise_scheduler.set_timesteps(
) 1000, device=self.device_torch, original_inference_steps=1000
)
else:
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']: if self.train_config.content_or_style in ['style', 'content']:
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename) optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
if os.path.exists(optimizer_state_file_path): if os.path.exists(optimizer_state_file_path):
# try to load # try to load
# previous param groups
# previous_params = copy.deepcopy(optimizer.param_groups)
try: try:
print(f"Loading optimizer state from {optimizer_state_file_path}") print(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path) optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
print(f"Failed to load optimizer state from {optimizer_state_file_path}") print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e) print(e)
# Update the learning rates if they changed
# optimizer.param_groups = previous_params
lr_scheduler_params = self.train_config.lr_scheduler_params lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it had bare minimum # make sure it had bare minimum

View File

@@ -635,9 +635,9 @@ class MaskFileItemDTOMixin:
# do a flip # do a flip
img.transpose(Image.FLIP_TOP_BOTTOM) img.transpose(Image.FLIP_TOP_BOTTOM)
# randomly apply a blur up to 2% of the size of the min (width, height) # randomly apply a blur up to 0.5% of the size of the min (width, height)
min_size = min(img.width, img.height) min_size = min(img.width, img.height)
blur_radius = int(min_size * random.random() * 0.02) blur_radius = int(min_size * random.random() * 0.005)
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
# make grayscale # make grayscale
@@ -794,7 +794,6 @@ class PoiFileItemDTOMixin:
except Exception as e: except Exception as e:
pass pass
# handle flipping # handle flipping
if kwargs.get('flip_x', False): if kwargs.get('flip_x', False):
# flip the poi # flip the poi
@@ -825,7 +824,28 @@ class PoiFileItemDTOMixin:
poi_width = int(self.poi_width * self.dataset_config.scale) poi_width = int(self.poi_width * self.dataset_config.scale)
poi_height = int(self.poi_height * self.dataset_config.scale) poi_height = int(self.poi_height * self.dataset_config.scale)
# determine new cropping # expand poi to fit resolution
if poi_width < resolution:
width_difference = resolution - poi_width
poi_x = poi_x - int(width_difference / 2)
poi_width = resolution
# make sure we dont go out of bounds
if poi_x < 0:
poi_x = 0
# if total width too much, crop
if poi_x + poi_width > initial_width:
poi_width = initial_width - poi_x
if poi_height < resolution:
height_difference = resolution - poi_height
poi_y = poi_y - int(height_difference / 2)
poi_height = resolution
# make sure we dont go out of bounds
if poi_y < 0:
poi_y = 0
# if total height too much, crop
if poi_y + poi_height > initial_height:
poi_height = initial_height - poi_y
# crop left # crop left
if poi_x > 0: if poi_x > 0:

View File

@@ -5,9 +5,12 @@ import json
import os import os
import io import io
import struct import struct
from typing import TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
import torch
from diffusers import AutoencoderTiny
FILE_UNKNOWN = "Sorry, don't know how to get size for this file." FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
@@ -424,23 +427,47 @@ def main(argv=None):
# Tracks whether the preview window has been created at least once,
# so on_exit can know if there is a window to clean up.
is_window_shown = False


def show_img(img, name='AI Toolkit'):
    """Display an RGB image array in an OpenCV preview window.

    Args:
        img: HWC array-like in RGB channel order; values are clipped to
            [0, 255] and cast to uint8 before display.
        name: title of the OpenCV window to draw into.

    Raises:
        KeyboardInterrupt: when the user presses Esc in the window.
    """
    global is_window_shown
    img = np.clip(img, 0, 255).astype(np.uint8)
    # cv2 expects BGR; the incoming image is RGB, so reverse the channel axis
    cv2.imshow(name, img[:, :, ::-1])
    k = cv2.waitKey(10) & 0xFF
    if k == 27:  # Esc key to stop
        print('\nESC pressed, stopping')
        raise KeyboardInterrupt
    # Unconditional assignment is equivalent to the previous
    # `if not is_window_shown:` guard and simpler.
    is_window_shown = True
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    """Render an image tensor in the preview window.

    Accepts either a batched (B, C, H, W) tensor, whose images are laid
    out side by side along the width axis, or a single (C, H, W) image.
    Pixel values are assumed to be in [-1, 1] and are mapped to [0, 255].

    Args:
        imgs: channels-first image tensor of rank 3 or 4, range [-1, 1].
        name: window title forwarded to show_img.
    """
    if len(imgs.shape) == 4:
        # split the batch and concatenate along width so all images show at once
        img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
        img = torch.cat(img_list, dim=3)
    else:
        # BUGFIX: the previous code concatenated a rank-3 tensor along dim=3,
        # which raises IndexError; promote the single image to a batch of one.
        img = imgs.unsqueeze(0)
    # img is -1 to 1, convert to 0 to 1
    img = img / 2 + 0.5
    img_numpy = img.to(torch.float32).detach().cpu().numpy()
    img_numpy = np.clip(img_numpy, 0, 1) * 255
    # move channels last for display: (B, C, H, W) -> (B, H, W, C)
    img_numpy = img_numpy.transpose(0, 2, 3, 1)
    img_numpy = img_numpy.astype(np.uint8)
    show_img(img_numpy[0], name=name)
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
    """Decode *latents* with *vae* and show the result via show_tensors.

    Args:
        latents: latent tensor to decode (scaled by the VAE's scaling factor).
        vae: tiny autoencoder used for decoding.
        name: window title forwarded to show_tensors.
    """
    # ensure the VAE lives on the same device as the latents before decoding
    if vae.device == 'cpu':
        vae.to(latents.device)
    unscaled = latents / vae.config['scaling_factor']
    decoded = vae.decode(unscaled).sample
    show_tensors(decoded, name=name)
def on_exit(): def on_exit():

View File

@@ -14,6 +14,7 @@ from PIL import Image
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file from safetensors.torch import save_file, load_file
from torch.nn import Parameter from torch.nn import Parameter
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm from tqdm import tqdm
from torchvision.transforms import Resize, transforms from torchvision.transforms import Resize, transforms
@@ -828,6 +829,8 @@ class StableDiffusion:
start_timesteps=0, start_timesteps=0,
guidance_scale=1, guidance_scale=1,
add_time_ids=None, add_time_ids=None,
bleed_ratio: float = 0.5,
bleed_latents: torch.FloatTensor = None,
**kwargs, **kwargs,
): ):
@@ -842,6 +845,10 @@ class StableDiffusion:
) )
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
# if not last step, and bleeding, bleed in some latents
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
# return latents_steps # return latents_steps
return latents return latents