mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-08 06:29:56 +00:00
Added some features for an LCM condenser plugin
This commit is contained in:
@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
if self.train_config.noise_scheduler == 'lcm':
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch, original_inference_steps=1000
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
|
||||
if os.path.exists(optimizer_state_file_path):
|
||||
# try to load
|
||||
# previous param groups
|
||||
# previous_params = copy.deepcopy(optimizer.param_groups)
|
||||
try:
|
||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
||||
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print(e)
|
||||
|
||||
# Update the learning rates if they changed
|
||||
# optimizer.param_groups = previous_params
|
||||
|
||||
lr_scheduler_params = self.train_config.lr_scheduler_params
|
||||
|
||||
# make sure it had bare minimum
|
||||
|
||||
@@ -635,9 +635,9 @@ class MaskFileItemDTOMixin:
|
||||
# do a flip
|
||||
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# randomly apply a blur up to 2% of the size of the min (width, height)
|
||||
# randomly apply a blur up to 0.5% of the size of the min (width, height)
|
||||
min_size = min(img.width, img.height)
|
||||
blur_radius = int(min_size * random.random() * 0.02)
|
||||
blur_radius = int(min_size * random.random() * 0.005)
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
||||
|
||||
# make grayscale
|
||||
@@ -794,7 +794,6 @@ class PoiFileItemDTOMixin:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
# handle flipping
|
||||
if kwargs.get('flip_x', False):
|
||||
# flip the poi
|
||||
@@ -825,7 +824,28 @@ class PoiFileItemDTOMixin:
|
||||
poi_width = int(self.poi_width * self.dataset_config.scale)
|
||||
poi_height = int(self.poi_height * self.dataset_config.scale)
|
||||
|
||||
# determine new cropping
|
||||
# expand poi to fit resolution
|
||||
if poi_width < resolution:
|
||||
width_difference = resolution - poi_width
|
||||
poi_x = poi_x - int(width_difference / 2)
|
||||
poi_width = resolution
|
||||
# make sure we dont go out of bounds
|
||||
if poi_x < 0:
|
||||
poi_x = 0
|
||||
# if total width too much, crop
|
||||
if poi_x + poi_width > initial_width:
|
||||
poi_width = initial_width - poi_x
|
||||
|
||||
if poi_height < resolution:
|
||||
height_difference = resolution - poi_height
|
||||
poi_y = poi_y - int(height_difference / 2)
|
||||
poi_height = resolution
|
||||
# make sure we dont go out of bounds
|
||||
if poi_y < 0:
|
||||
poi_y = 0
|
||||
# if total height too much, crop
|
||||
if poi_y + poi_height > initial_height:
|
||||
poi_height = initial_height - poi_y
|
||||
|
||||
# crop left
|
||||
if poi_x > 0:
|
||||
|
||||
@@ -5,9 +5,12 @@ import json
|
||||
import os
|
||||
import io
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
|
||||
|
||||
@@ -424,23 +427,47 @@ def main(argv=None):
|
||||
is_window_shown = False
|
||||
|
||||
|
||||
def show_img(img):
|
||||
def show_img(img, name='AI Toolkit'):
|
||||
global is_window_shown
|
||||
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
cv2.imshow('AI Toolkit', img[:, :, ::-1])
|
||||
cv2.imshow(name, img[:, :, ::-1])
|
||||
k = cv2.waitKey(10) & 0xFF
|
||||
if k == 27: # Esc key to stop
|
||||
print('\nESC pressed, stopping')
|
||||
raise KeyboardInterrupt
|
||||
# show again to initialize the window if first
|
||||
if not is_window_shown:
|
||||
cv2.imshow('AI Toolkit', img[:, :, ::-1])
|
||||
k = cv2.waitKey(10) & 0xFF
|
||||
if k == 27: # Esc key to stop
|
||||
print('\nESC pressed, stopping')
|
||||
raise KeyboardInterrupt
|
||||
is_window_shown = True
|
||||
is_window_shown = True
|
||||
|
||||
|
||||
|
||||
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
|
||||
# if rank is 4
|
||||
if len(imgs.shape) == 4:
|
||||
img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
|
||||
else:
|
||||
img_list = [imgs]
|
||||
# put images side by side
|
||||
img = torch.cat(img_list, dim=3)
|
||||
# img is -1 to 1, convert to 0 to 255
|
||||
img = img / 2 + 0.5
|
||||
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
||||
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
||||
# convert to numpy Move channel to last
|
||||
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
||||
# convert to uint8
|
||||
img_numpy = img_numpy.astype(np.uint8)
|
||||
show_img(img_numpy[0], name=name)
|
||||
|
||||
|
||||
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||
# decode latents
|
||||
if vae.device == 'cpu':
|
||||
vae.to(latents.device)
|
||||
latents = latents / vae.config['scaling_factor']
|
||||
imgs = vae.decode(latents).sample
|
||||
show_tensors(imgs, name=name)
|
||||
|
||||
|
||||
|
||||
def on_exit():
|
||||
|
||||
@@ -14,6 +14,7 @@ from PIL import Image
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.nn import Parameter
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize, transforms
|
||||
|
||||
@@ -828,6 +829,8 @@ class StableDiffusion:
|
||||
start_timesteps=0,
|
||||
guidance_scale=1,
|
||||
add_time_ids=None,
|
||||
bleed_ratio: float = 0.5,
|
||||
bleed_latents: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -842,6 +845,10 @@ class StableDiffusion:
|
||||
)
|
||||
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
||||
|
||||
# if not last step, and bleeding, bleed in some latents
|
||||
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
||||
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
|
||||
|
||||
# return latents_steps
|
||||
return latents
|
||||
|
||||
|
||||
Reference in New Issue
Block a user