mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added some features for an LCM condenser plugin
This commit is contained in:
@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
with self.timer('prepare_noise'):
|
with self.timer('prepare_noise'):
|
||||||
|
|
||||||
self.sd.noise_scheduler.set_timesteps(
|
if self.train_config.noise_scheduler == 'lcm':
|
||||||
1000, device=self.device_torch
|
self.sd.noise_scheduler.set_timesteps(
|
||||||
)
|
1000, device=self.device_torch, original_inference_steps=1000
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.sd.noise_scheduler.set_timesteps(
|
||||||
|
1000, device=self.device_torch
|
||||||
|
)
|
||||||
|
|
||||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||||
if self.train_config.content_or_style in ['style', 'content']:
|
if self.train_config.content_or_style in ['style', 'content']:
|
||||||
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
|
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
|
||||||
if os.path.exists(optimizer_state_file_path):
|
if os.path.exists(optimizer_state_file_path):
|
||||||
# try to load
|
# try to load
|
||||||
|
# previous param groups
|
||||||
|
# previous_params = copy.deepcopy(optimizer.param_groups)
|
||||||
try:
|
try:
|
||||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||||
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
||||||
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
# Update the learning rates if they changed
|
||||||
|
# optimizer.param_groups = previous_params
|
||||||
|
|
||||||
lr_scheduler_params = self.train_config.lr_scheduler_params
|
lr_scheduler_params = self.train_config.lr_scheduler_params
|
||||||
|
|
||||||
# make sure it had bare minimum
|
# make sure it had bare minimum
|
||||||
|
|||||||
@@ -635,9 +635,9 @@ class MaskFileItemDTOMixin:
|
|||||||
# do a flip
|
# do a flip
|
||||||
img.transpose(Image.FLIP_TOP_BOTTOM)
|
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
# randomly apply a blur up to 2% of the size of the min (width, height)
|
# randomly apply a blur up to 0.5% of the size of the min (width, height)
|
||||||
min_size = min(img.width, img.height)
|
min_size = min(img.width, img.height)
|
||||||
blur_radius = int(min_size * random.random() * 0.02)
|
blur_radius = int(min_size * random.random() * 0.005)
|
||||||
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
||||||
|
|
||||||
# make grayscale
|
# make grayscale
|
||||||
@@ -794,7 +794,6 @@ class PoiFileItemDTOMixin:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# handle flipping
|
# handle flipping
|
||||||
if kwargs.get('flip_x', False):
|
if kwargs.get('flip_x', False):
|
||||||
# flip the poi
|
# flip the poi
|
||||||
@@ -825,7 +824,28 @@ class PoiFileItemDTOMixin:
|
|||||||
poi_width = int(self.poi_width * self.dataset_config.scale)
|
poi_width = int(self.poi_width * self.dataset_config.scale)
|
||||||
poi_height = int(self.poi_height * self.dataset_config.scale)
|
poi_height = int(self.poi_height * self.dataset_config.scale)
|
||||||
|
|
||||||
# determine new cropping
|
# expand poi to fit resolution
|
||||||
|
if poi_width < resolution:
|
||||||
|
width_difference = resolution - poi_width
|
||||||
|
poi_x = poi_x - int(width_difference / 2)
|
||||||
|
poi_width = resolution
|
||||||
|
# make sure we dont go out of bounds
|
||||||
|
if poi_x < 0:
|
||||||
|
poi_x = 0
|
||||||
|
# if total width too much, crop
|
||||||
|
if poi_x + poi_width > initial_width:
|
||||||
|
poi_width = initial_width - poi_x
|
||||||
|
|
||||||
|
if poi_height < resolution:
|
||||||
|
height_difference = resolution - poi_height
|
||||||
|
poi_y = poi_y - int(height_difference / 2)
|
||||||
|
poi_height = resolution
|
||||||
|
# make sure we dont go out of bounds
|
||||||
|
if poi_y < 0:
|
||||||
|
poi_y = 0
|
||||||
|
# if total height too much, crop
|
||||||
|
if poi_y + poi_height > initial_height:
|
||||||
|
poi_height = initial_height - poi_y
|
||||||
|
|
||||||
# crop left
|
# crop left
|
||||||
if poi_x > 0:
|
if poi_x > 0:
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ import json
|
|||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import struct
|
import struct
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoencoderTiny
|
||||||
|
|
||||||
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
|
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
|
||||||
|
|
||||||
@@ -424,23 +427,47 @@ def main(argv=None):
|
|||||||
is_window_shown = False
|
is_window_shown = False
|
||||||
|
|
||||||
|
|
||||||
def show_img(img):
|
def show_img(img, name='AI Toolkit'):
|
||||||
global is_window_shown
|
global is_window_shown
|
||||||
|
|
||||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||||
cv2.imshow('AI Toolkit', img[:, :, ::-1])
|
cv2.imshow(name, img[:, :, ::-1])
|
||||||
k = cv2.waitKey(10) & 0xFF
|
k = cv2.waitKey(10) & 0xFF
|
||||||
if k == 27: # Esc key to stop
|
if k == 27: # Esc key to stop
|
||||||
print('\nESC pressed, stopping')
|
print('\nESC pressed, stopping')
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
# show again to initialize the window if first
|
|
||||||
if not is_window_shown:
|
if not is_window_shown:
|
||||||
cv2.imshow('AI Toolkit', img[:, :, ::-1])
|
is_window_shown = True
|
||||||
k = cv2.waitKey(10) & 0xFF
|
|
||||||
if k == 27: # Esc key to stop
|
|
||||||
print('\nESC pressed, stopping')
|
|
||||||
raise KeyboardInterrupt
|
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
|
||||||
is_window_shown = True
|
# if rank is 4
|
||||||
|
if len(imgs.shape) == 4:
|
||||||
|
img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
|
||||||
|
else:
|
||||||
|
img_list = [imgs]
|
||||||
|
# put images side by side
|
||||||
|
img = torch.cat(img_list, dim=3)
|
||||||
|
# img is -1 to 1, convert to 0 to 255
|
||||||
|
img = img / 2 + 0.5
|
||||||
|
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
||||||
|
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
||||||
|
# convert to numpy Move channel to last
|
||||||
|
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
||||||
|
# convert to uint8
|
||||||
|
img_numpy = img_numpy.astype(np.uint8)
|
||||||
|
show_img(img_numpy[0], name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||||
|
# decode latents
|
||||||
|
if vae.device == 'cpu':
|
||||||
|
vae.to(latents.device)
|
||||||
|
latents = latents / vae.config['scaling_factor']
|
||||||
|
imgs = vae.decode(latents).sample
|
||||||
|
show_tensors(imgs, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_exit():
|
def on_exit():
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from PIL import Image
|
|||||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||||
from safetensors.torch import save_file, load_file
|
from safetensors.torch import save_file, load_file
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision.transforms import Resize, transforms
|
from torchvision.transforms import Resize, transforms
|
||||||
|
|
||||||
@@ -828,6 +829,8 @@ class StableDiffusion:
|
|||||||
start_timesteps=0,
|
start_timesteps=0,
|
||||||
guidance_scale=1,
|
guidance_scale=1,
|
||||||
add_time_ids=None,
|
add_time_ids=None,
|
||||||
|
bleed_ratio: float = 0.5,
|
||||||
|
bleed_latents: torch.FloatTensor = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
@@ -842,6 +845,10 @@ class StableDiffusion:
|
|||||||
)
|
)
|
||||||
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
# if not last step, and bleeding, bleed in some latents
|
||||||
|
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
||||||
|
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
|
||||||
|
|
||||||
# return latents_steps
|
# return latents_steps
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user