Added some features for an LCM condenser plugin

Jaret Burkett
2023-11-15 08:56:45 -07:00
parent 4f9cdd916a
commit e47006ed70
4 changed files with 80 additions and 16 deletions

View File

@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
        with self.timer('prepare_noise'):
-            self.sd.noise_scheduler.set_timesteps(
-                1000, device=self.device_torch
-            )
+            if self.train_config.noise_scheduler == 'lcm':
+                self.sd.noise_scheduler.set_timesteps(
+                    1000, device=self.device_torch, original_inference_steps=1000
+                )
+            else:
+                self.sd.noise_scheduler.set_timesteps(
+                    1000, device=self.device_torch
+                )
        # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
        if self.train_config.content_or_style in ['style', 'content']:
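For context, diffusers' LCMScheduler will not emit more inference timesteps than its origin schedule contains, so sampling across all 1000 training timesteps needs original_inference_steps widened to match. A minimal standalone sketch of the distinction (stock diffusers defaults, CPU device assumed):

from diffusers import LCMScheduler

scheduler = LCMScheduler()  # 1000 training timesteps by default
# with the default original_inference_steps (50), requesting 1000 steps
# here would raise: they cannot be drawn from a 50-step origin schedule
scheduler.set_timesteps(1000, device='cpu', original_inference_steps=1000)
print(len(scheduler.timesteps))  # 1000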
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
        if os.path.exists(optimizer_state_file_path):
            # try to load
+            # previous param groups
+            # previous_params = copy.deepcopy(optimizer.param_groups)
            try:
                print(f"Loading optimizer state from {optimizer_state_file_path}")
                optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)
# Update the learning rates if they changed
# optimizer.param_groups = previous_params
lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it had bare minimum
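The commented-out previous_params lines are remnants of an abandoned restore path; the active pattern is to load the checkpointed optimizer state and then push the configured learning rate back into the param groups, since load_state_dict restores whatever lr was saved. A minimal sketch of that pattern (checkpoint path hypothetical):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
optimizer.load_state_dict(torch.load('optimizer.pt'))  # hypothetical checkpoint path
for group in optimizer.param_groups:
    group['lr'] = 5e-5  # the configured lr wins over the checkpointed one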

View File

@@ -635,9 +635,9 @@ class MaskFileItemDTOMixin:
            # do a flip
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
-        # randomly apply a blur up to 2% of the size of the min (width, height)
+        # randomly apply a blur up to 0.5% of the size of the min (width, height)
        min_size = min(img.width, img.height)
-        blur_radius = int(min_size * random.random() * 0.02)
+        blur_radius = int(min_size * random.random() * 0.005)
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        # make grayscale
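For scale: on a 1024x768 mask, min_size is 768, so the maximum blur radius drops from int(768 * 0.02) = 15 px to int(768 * 0.005) = 3 px, with random.random() still drawing the radius uniformly from that range.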
@@ -794,7 +794,6 @@ class PoiFileItemDTOMixin:
        except Exception as e:
            pass
        # handle flipping
        if kwargs.get('flip_x', False):
            # flip the poi
@@ -825,7 +824,28 @@ class PoiFileItemDTOMixin:
        poi_width = int(self.poi_width * self.dataset_config.scale)
        poi_height = int(self.poi_height * self.dataset_config.scale)
        # determine new cropping
+        # expand poi to fit resolution
+        if poi_width < resolution:
+            width_difference = resolution - poi_width
+            poi_x = poi_x - int(width_difference / 2)
+            poi_width = resolution
+            # make sure we don't go out of bounds
+            if poi_x < 0:
+                poi_x = 0
+            # if the total width is too much, crop
+            if poi_x + poi_width > initial_width:
+                poi_width = initial_width - poi_x
+        if poi_height < resolution:
+            height_difference = resolution - poi_height
+            poi_y = poi_y - int(height_difference / 2)
+            poi_height = resolution
+            # make sure we don't go out of bounds
+            if poi_y < 0:
+                poi_y = 0
+            # if the total height is too much, crop
+            if poi_y + poi_height > initial_height:
+                poi_height = initial_height - poi_y
        # crop left
        if poi_x > 0:
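The added block recenters an undersized point-of-interest crop and grows it to the target resolution, clamping to the image bounds. Condensed into a standalone helper for illustration (the name expand_poi is ours, not the repo's):

def expand_poi(poi_x, poi_y, poi_width, poi_height, resolution, initial_width, initial_height):
    # grow each axis to at least `resolution`, centering the expansion
    if poi_width < resolution:
        poi_x = max(poi_x - (resolution - poi_width) // 2, 0)
        poi_width = min(resolution, initial_width - poi_x)
    if poi_height < resolution:
        poi_y = max(poi_y - (resolution - poi_height) // 2, 0)
        poi_height = min(resolution, initial_height - poi_y)
    return poi_x, poi_y, poi_width, poi_height

# a 100x80 box at (200, 220) in a 512x512 image, expanded for a 256 crop:
print(expand_poi(200, 220, 100, 80, 256, 512, 512))  # (122, 132, 256, 256)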

View File

@@ -5,9 +5,12 @@ import json
import os
import io
import struct
+from typing import TYPE_CHECKING
import cv2
import numpy as np
+import torch
+from diffusers import AutoencoderTiny

FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
@@ -424,23 +427,47 @@ def main(argv=None):
is_window_shown = False

-def show_img(img):
+def show_img(img, name='AI Toolkit'):
    global is_window_shown
    img = np.clip(img, 0, 255).astype(np.uint8)
-    cv2.imshow('AI Toolkit', img[:, :, ::-1])
+    cv2.imshow(name, img[:, :, ::-1])
    k = cv2.waitKey(10) & 0xFF
    if k == 27:  # Esc key to stop
        print('\nESC pressed, stopping')
        raise KeyboardInterrupt
    # show again to initialize the window if first
    if not is_window_shown:
        cv2.imshow('AI Toolkit', img[:, :, ::-1])
        k = cv2.waitKey(10) & 0xFF
        if k == 27:  # Esc key to stop
            print('\nESC pressed, stopping')
            raise KeyboardInterrupt
        is_window_shown = True
+def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
+    # if rank is 4, treat as a batch and split it
+    if len(imgs.shape) == 4:
+        img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
+    else:
+        # add a batch dimension so the cat/transpose below see rank 4
+        img_list = [imgs.unsqueeze(0)]
+    # put images side by side
+    img = torch.cat(img_list, dim=3)
+    # img is -1 to 1, convert to 0 to 255
+    img = img / 2 + 0.5
+    img_numpy = img.to(torch.float32).detach().cpu().numpy()
+    img_numpy = np.clip(img_numpy, 0, 1) * 255
+    # now in numpy; move the channel axis to last
+    img_numpy = img_numpy.transpose(0, 2, 3, 1)
+    # convert to uint8
+    img_numpy = img_numpy.astype(np.uint8)
+    show_img(img_numpy[0], name=name)
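A quick shape trace of the tiling above (example values, not from the repo):

import torch

imgs = torch.rand(2, 3, 64, 64) * 2 - 1           # two RGB images in [-1, 1]
chunks = torch.chunk(imgs, imgs.shape[0], dim=0)  # two tensors of (1, 3, 64, 64)
tiled = torch.cat(chunks, dim=3)                  # (1, 3, 64, 128), side by side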
+def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
+    # decode latents to pixel space first
+    if vae.device == 'cpu':
+        vae.to(latents.device)
+    latents = latents / vae.config['scaling_factor']
+    imgs = vae.decode(latents).sample
+    show_tensors(imgs, name=name)
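AutoencoderTiny fits here because a full VAE decode is too slow to run as a per-step preview. A minimal wiring sketch (the madebyollin/taesd checkpoint and CUDA device are assumptions, not from this commit):

import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained('madebyollin/taesd', torch_dtype=torch.float16)
vae = vae.to('cuda')
# inside a training step, preview the current latents in a cv2 window:
# show_latents(latents, vae, name='denoised preview')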
def on_exit():

View File

@@ -14,6 +14,7 @@ from PIL import Image
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
+from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from torchvision.transforms import Resize, transforms
@@ -828,6 +829,8 @@ class StableDiffusion:
        start_timesteps=0,
        guidance_scale=1,
        add_time_ids=None,
+        bleed_ratio: float = 0.5,
+        bleed_latents: torch.FloatTensor = None,
        **kwargs,
    ):
@@ -842,6 +845,10 @@ class StableDiffusion:
            )
            latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+            # if not the last step, and bleeding is enabled, bleed in some latents
+            if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
+                latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
        # return latents_steps
        return latents
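The bleed is a per-step linear interpolation that keeps pulling the denoising trajectory toward a fixed reference latent. In isolation (shapes assumed, SD-style 4-channel latents):

import torch

latents = torch.randn(1, 4, 64, 64)        # current denoising state
bleed_latents = torch.randn(1, 4, 64, 64)  # fixed reference latents
bleed_ratio = 0.5
# identical to latents * (1 - bleed_ratio) + bleed_latents * bleed_ratio
latents = torch.lerp(latents, bleed_latents, bleed_ratio)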