Initial commit

This commit is contained in:
Jaret Burkett
2023-12-29 13:07:35 -07:00
parent 0892dec4a5
commit bafacf3b65
5 changed files with 175 additions and 37 deletions

View File

@@ -198,6 +198,8 @@ class TrainConfig:
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.train_refiner = kwargs.get('train_refiner', True)
self.train_turbo = kwargs.get('train_turbo', False)
self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.snr_gamma = kwargs.get('snr_gamma', None)
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
@@ -263,6 +265,9 @@ class TrainConfig:
self.standardize_images = kwargs.get('standardize_images', False)
self.standardize_latents = kwargs.get('standardize_latents', False)
if self.train_turbo and not self.noise_scheduler.startswith("euler"):
raise ValueError(f"train_turbo is only supported with euler and euler_a noise schedulers")
class ModelConfig:
def __init__(self, **kwargs):

View File

@@ -5,6 +5,7 @@ import json
import os
import io
import struct
import threading
from typing import TYPE_CHECKING
import cv2
@@ -425,43 +426,63 @@ def main(argv=None):
is_window_shown = False
display_lock = threading.Lock()
current_img = None
update_event = threading.Event()
def update_image(img, name):
    """Publish a frame for the background display thread to render.

    Stores ``(img, name)`` under ``display_lock`` and signals the
    display loop via ``update_event``. Callers pass ``img`` already in
    BGR channel order ready for ``cv2.imshow`` — TODO confirm, the
    visible caller does ``img[:, :, ::-1]`` before calling.
    """
    global current_img
    # Hand the frame off under the lock so the display thread never
    # observes a half-written tuple.
    with display_lock:
        current_img = (img, name)
        # Wake the display loop blocked in update_event.wait().
        # NOTE(review): source indentation was lost in extraction; set()
        # may have sat outside the `with` originally — either placement
        # wakes the consumer correctly.
        update_event.set()
def display_image_in_thread():
    """Start (at most once) a daemon thread that displays queued images.

    All OpenCV GUI calls happen on this single background thread: it
    waits on ``update_event``, draws the pending frame published by
    ``update_image``, then pumps the HighGUI event loop with
    ``cv2.waitKey``. Pressing Esc destroys all windows and ends the
    loop. The module-level ``is_window_shown`` flag guards against
    spawning a second thread.
    """
    global is_window_shown

    def display_img():
        global current_img
        while True:
            # Block until update_image() publishes a new frame.
            update_event.wait()
            with display_lock:
                if current_img:
                    img, name = current_img
                    cv2.imshow(name, img)
                    # Consume the frame and re-arm the event.
                    current_img = None
                    update_event.clear()
            # waitKey both pumps the GUI event loop and polls the keyboard.
            if cv2.waitKey(1) & 0xFF == 27:  # Esc key to stop
                cv2.destroyAllWindows()
                print('\nESC pressed, stopping')
                break

    # Spawn the display thread only once per process; daemon=True lets
    # the interpreter exit without joining it.
    if not is_window_shown:
        is_window_shown = True
        threading.Thread(target=display_img, daemon=True).start()
def show_img(img, name='AI Toolkit'):
    """Queue an RGB image (float or uint8, HWC) for asynchronous display.

    Clips values to [0, 255], converts to uint8, flips RGB -> BGR for
    OpenCV, and hands the frame to the background display thread,
    starting that thread on first use.

    NOTE(review): this span rendered both the pre- and post-commit
    bodies of show_img with no diff markers; the superseded direct
    cv2.imshow/waitKey/KeyboardInterrupt lines (replaced in this commit
    by the threaded display path) have been dropped to restore the
    post-commit function.
    """
    global is_window_shown
    img = np.clip(img, 0, 255).astype(np.uint8)
    # OpenCV expects BGR channel order; publish to the display thread.
    update_image(img[:, :, ::-1], name)
    # Lazily start the display thread the first time a frame is shown.
    if not is_window_shown:
        is_window_shown = True
        display_image_in_thread()
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    """Display an image tensor (batch laid out side by side) via show_img.

    imgs: float tensor in [-1, 1], shaped (B, C, H, W) or (C, H, W).
    A batch is split and concatenated along the width axis so all
    images appear in one window.
    """
    # Normalize to rank 4 so the chunk/cat/transpose below always see
    # (B, C, H, W). (The previous else-branch passed a rank-3 tensor
    # straight into torch.cat(dim=3) / transpose(0, 2, 3, 1), both of
    # which raise on a 3-D tensor.)
    if len(imgs.shape) != 4:
        imgs = imgs.unsqueeze(0)
    img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
    # Put images side by side: concatenate along the width dimension.
    img = torch.cat(img_list, dim=3)
    # img is -1 to 1; map to 0..1 then scale to 0..255.
    img = img / 2 + 0.5
    img_numpy = img.to(torch.float32).detach().cpu().numpy()
    img_numpy = np.clip(img_numpy, 0, 1) * 255
    # Move channel axis last: (1, C, H, W') -> (1, H, W', C).
    img_numpy = img_numpy.transpose(0, 2, 3, 1)
    # convert to uint8
    img_numpy = img_numpy.astype(np.uint8)
    show_img(img_numpy[0], name=name)
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
# decode latents
if vae.device == 'cpu':
vae.to(latents.device)
latents = latents / vae.config['scaling_factor']
@@ -469,7 +490,6 @@ def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit
show_tensors(imgs, name=name)
def on_exit():
    """Shutdown hook: close all OpenCV windows if any were ever opened.

    Reads the module-level ``is_window_shown`` flag set by the display
    path; a no-op when no window was shown (avoids touching HighGUI
    state in headless runs).
    """
    if is_window_shown:
        cv2.destroyAllWindows()

View File

@@ -5,8 +5,8 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
@@ -43,13 +43,14 @@ class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
unet=unet,
scheduler=scheduler,
)
self.sampler = None
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
raise NotImplementedError("This pipeline is not implemented yet")
# self.sampler = None
# scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
# model = ModelWrapper(unet, scheduler.alphas_cumprod)
# if scheduler.config.prediction_type == "v_prediction":
# self.k_diffusion_model = CompVisVDenoiser(model)
# else:
# self.k_diffusion_model = CompVisDenoiser(model)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")

View File

@@ -755,7 +755,7 @@ class StableDiffusion:
add_time_ids=None,
conditional_embeddings: Union[PromptEmbeds, None] = None,
unconditional_embeddings: Union[PromptEmbeds, None] = None,
is_input_scaled=True,
is_input_scaled=False,
**kwargs,
):
with torch.no_grad():