mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
reworked samplers. Trying to find what is wrong with diffusers sampling in sdxl
This commit is contained in:
@@ -1,3 +1,4 @@
+import copy
import glob
from collections import OrderedDict
import os
@@ -11,6 +12,7 @@ from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
+from toolkit.sampler import get_sampler

from toolkit.scheduler import get_lr_scheduler
from toolkit.stable_diffusion_model import StableDiffusion
@@ -89,6 +91,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if embedding_raw is not None:
            self.embed_config = EmbeddingConfig(**embedding_raw)
+
+        model_config_to_load = copy.deepcopy(self.model_config)

        if self.embed_config is None and self.network_config is None:
            # get the latest checkpoint
            # check to see if we have a latest save
@@ -96,7 +100,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

            if latest_save_path is not None:
                print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                self.model_config.name_or_path = latest_save_path
+                model_config_to_load.name_or_path = latest_save_path
                meta = load_metadata_from_safetensors(latest_save_path)
                # if 'training_info' in OrderedDict keys
                if 'training_info' in meta and 'step' in meta['training_info']:
@@ -104,11 +108,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    self.start_step = self.step_num
                    print(f"Found step {self.step_num} in metadata, starting from there")

+        # get the noise scheduler
+        sampler = get_sampler(self.train_config.noise_scheduler)
+
        self.sd = StableDiffusion(
            device=self.device,
-            model_config=self.model_config,
+            model_config=model_config_to_load,
            dtype=self.train_config.dtype,
            custom_pipeline=self.custom_pipeline,
+            noise_scheduler=sampler,
        )

        # to hold network if there is one
@@ -164,7 +172,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            ))

        # send to be generated
-        self.sd.generate_images(gen_img_config_list)
+        self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)

    def update_training_metadata(self):
        o_dict = OrderedDict({
@@ -216,10 +224,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
            for file in files[:-self.save_config.max_step_saves_to_keep]:
                self.print(f"Removing old save: {file}")
                os.remove(file)
+                # see if a yaml file with same name exists
+                yaml_file = os.path.splitext(file)[0] + ".yaml"
+                if os.path.exists(yaml_file):
+                    os.remove(yaml_file)
            return latest_file
        else:
            return None

+    def post_save_hook(self, save_path):
+        # override in subclass
+        pass
+
    def save(self, step=None):
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)
@@ -263,6 +279,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

        self.print(f"Saved to {file_path}")
        self.clean_up_saves()
+        self.post_save_hook(file_path)

    # Called before the model is loaded
    def hook_before_model_load(self):
@@ -279,6 +296,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def before_dataset_load(self):
        pass

+    def get_params(self):
+        # you can extend this in subclass to get params
+        # otherwise params will be gathered through normal means
+        return None
+
    def hook_train_loop(self, batch):
        # return loss
        return 0.0
@@ -445,11 +467,14 @@ class BaseSDTrainProcess(BaseTrainProcess):

            self.network.prepare_grad_etc(text_encoder, unet)

-            params = self.network.prepare_optimizer_params(
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if not params:
+                params = self.network.prepare_optimizer_params(
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )

            if self.train_config.gradient_checkpointing:
                self.network.enable_gradient_checkpointing()
@@ -477,8 +502,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                self.step_num = self.embedding.step
                self.start_step = self.step_num

-                # set trainable params
-                params = self.embedding.get_trainable_params()
+                params = self.get_params()
+                if not params:
+                    # set trainable params
+                    params = self.embedding.get_trainable_params()

        else:
            # set them to train or not
@@ -506,14 +533,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.sd.text_encoder.requires_grad_(False)
            self.sd.text_encoder.eval()

-            # will only return savable weights and ones with grad
-            params = self.sd.prepare_optimizer_params(
-                unet=self.train_config.train_unet,
-                text_encoder=self.train_config.train_text_encoder,
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if params is None:
+                # will only return savable weights and ones with grad
+                params = self.sd.prepare_optimizer_params(
+                    unet=self.train_config.train_unet,
+                    text_encoder=self.train_config.train_text_encoder,
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )

        ### HOOK ###
        params = self.hook_add_extra_train_params(params)
@@ -98,7 +98,7 @@ class GenerateProcess(BaseProcess):
                add_prompt_file=self.generate_config.prompt_file
            ))
        # generate images
-        self.sd.generate_images(prompt_image_configs)
+        self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)

        print("Done generating images")
        # cleanup
@@ -15,4 +15,5 @@ accelerate
toml
albumentations
pydantic
-omegaconf
+omegaconf
+k-diffusion
@@ -22,6 +22,7 @@ class LogingConfig:

class SampleConfig:
    def __init__(self, **kwargs):
+        self.sampler: str = kwargs.get('sampler', 'ddpm')
        self.sample_every: int = kwargs.get('sample_every', 100)
        self.width: int = kwargs.get('width', 512)
        self.height: int = kwargs.get('height', 512)
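Aside: a minimal sketch of how the new `sampler` key travels from config to generation (the raw dict below is a made-up example, not part of this commit):

    sample_raw = {'sampler': 'euler_a', 'width': 1024, 'height': 1024}
    sample_config = SampleConfig(**sample_raw)
    # BaseSDTrainProcess then forwards it when sampling:
    # self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)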
@@ -35,6 +35,10 @@ class LosslessLatentDecoder(nn.Module):
        return kernel

    def forward(self, x):
+        dtype = x.dtype
+        if self.kernel.dtype != dtype:
+            self.kernel = self.kernel.to(dtype=dtype)
+
        # Deconvolve input tensor with the kernel
        return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)

@@ -70,6 +74,9 @@ class LosslessLatentEncoder(nn.Module):
        return kernel

    def forward(self, x):
+        dtype = x.dtype
+        if self.kernel.dtype != dtype:
+            self.kernel = self.kernel.to(dtype=dtype)
        # Convolve input tensor with the kernel
        return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
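Aside: the dtype cast is the only change these forward passes need because the encode/decode pair is exact by construction. A self-contained sketch of the same space-to-depth idea with a one-hot kernel (illustrative, not the toolkit's actual kernel-building code):

    import torch
    import torch.nn as nn

    channels, depth = 4, 2  # depth plays the role of latent_depth
    # one-hot kernel: output channel c*d*d + i*d + j copies input pixel (c, i, j)
    kernel = torch.zeros(channels * depth * depth, channels, depth, depth)
    for c in range(channels):
        for i in range(depth):
            for j in range(depth):
                kernel[c * depth * depth + i * depth + j, c, i, j] = 1.0

    x = torch.randn(1, channels, 8, 8)
    latent = nn.functional.conv2d(x, kernel, stride=depth)                 # encode: (1, 16, 4, 4)
    recon = nn.functional.conv_transpose2d(latent, kernel, stride=depth)   # exact inverse
    assert torch.allclose(x, recon)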
@@ -1,14 +1,289 @@
+import importlib
+import inspect
from typing import Union, List, Optional, Dict, Any, Tuple, Callable

import torch
-from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
+from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
+from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler


class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):

    def __init__(
            self,
            vae: 'AutoencoderKL',
            text_encoder: 'CLIPTextModel',
            text_encoder_2: 'CLIPTextModelWithProjection',
            tokenizer: 'CLIPTokenizer',
            tokenizer_2: 'CLIPTokenizer',
            unet: 'UNet2DConditionModel',
            scheduler: 'KarrasDiffusionSchedulers',
            force_zeros_for_empty_prompt: bool = True,
            add_watermarker: Optional[bool] = None,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.sampler = None
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        model = ModelWrapper(unet, scheduler.alphas_cumprod)
        if scheduler.config.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def set_scheduler(self, scheduler_type: str):
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
        self.sampler = getattr(sampling, scheduler_type)

    @torch.no_grad()
    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            prompt_2: Optional[Union[str, List[str]]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            denoising_end: Optional[float] = None,
            guidance_scale: float = 5.0,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            negative_prompt_2: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            original_size: Optional[Tuple[int, int]] = None,
            crops_coords_top_left: Tuple[int, int] = (0, 0),
            target_size: Optional[Tuple[int, int]] = None,
            use_karras_sigmas: bool = False,
    ):

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 5. Prepare sigmas
        if use_karras_sigmas:
            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
            sigmas = sigmas.to(device)
        else:
            sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(prompt_embeds.dtype)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 7. Define model function
        def model_fn(x, t):
            latent_model_input = torch.cat([x] * 2)
            t = torch.cat([t] * 2)

            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            # noise_pred = self.unet(
            #     latent_model_input,
            #     t,
            #     encoder_hidden_states=prompt_embeds,
            #     cross_attention_kwargs=cross_attention_kwargs,
            #     added_cond_kwargs=added_cond_kwargs,
            #     return_dict=False,
            # )[0]

            noise_pred = self.k_diffusion_model(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred

        # 8. Run k-diffusion solver
        sampler_kwargs = {}
        # should work without it
        noise_sampler_seed = None

        if "noise_sampler" in inspect.signature(self.sampler).parameters:
            min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
            noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
            sampler_kwargs["noise_sampler"] = noise_sampler

        latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)


class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    # def __init__(self, *args, **kwargs):
    #     super().__init__(*args, **kwargs)

    def predict_noise(
            self,
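For orientation, the pattern this new pipeline borrows from k-diffusion, stripped to its skeleton (a hedged sketch: `denoiser` stands in for `self.k_diffusion_model`, and the real `model_fn` above also passes the SDXL added-condition kwargs):

    import torch
    from k_diffusion.sampling import sample_euler

    @torch.no_grad()
    def k_sample(denoiser, latents, sigmas, guidance_scale):
        def model_fn(x, t):
            # double the batch so cond/uncond run in one forward pass (CFG)
            x_in, t_in = torch.cat([x] * 2), torch.cat([t] * 2)
            uncond, cond = denoiser(x_in, t_in).chunk(2)
            return uncond + guidance_scale * (cond - uncond)

        latents = latents * sigmas[0]  # k-diffusion expects sigma-scaled noise
        return sample_euler(model_fn, latents, sigmas)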
107  toolkit/sampler.py  Normal file
@@ -0,0 +1,107 @@
import copy

from diffusers import (
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from k_diffusion.external import CompVisDenoiser

# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"

sdxl_sampler_config = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.19.0.dev0",
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "clip_sample": False,
    "interpolation_type": "linear",
    "num_train_timesteps": 1000,
    "prediction_type": "epsilon",
    "sample_max_value": 1.0,
    "set_alpha_to_one": False,
    "skip_prk_steps": True,
    "steps_offset": 1,
    "timestep_spacing": "leading",
    "trained_betas": None,
    "use_karras_sigmas": False
}


def get_sampler(
        sampler: str,
):
    sched_init_args = {}

    if sampler.startswith("k_"):
        sched_init_args["use_karras_sigmas"] = True

    if sampler == "ddim":
        scheduler_cls = DDIMScheduler
    elif sampler == "ddpm":  # ddpm is not supported ?
        scheduler_cls = DDPMScheduler
    elif sampler == "pndm":
        scheduler_cls = PNDMScheduler
    elif sampler == "lms" or sampler == "k_lms":
        scheduler_cls = LMSDiscreteScheduler
    elif sampler == "euler" or sampler == "k_euler":
        scheduler_cls = EulerDiscreteScheduler
    elif sampler == "euler_a":
        scheduler_cls = EulerAncestralDiscreteScheduler
    elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = sampler.replace("k_", "")
    elif sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif sampler == "heun":
        scheduler_cls = HeunDiscreteScheduler
    elif sampler == "dpm_2":
        scheduler_cls = KDPM2DiscreteScheduler
    elif sampler == "dpm_2_a":
        scheduler_cls = KDPM2AncestralDiscreteScheduler

    config = copy.deepcopy(sdxl_sampler_config)
    config.update(sched_init_args)

    scheduler = scheduler_cls.from_config(config)

    return scheduler


# testing
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    from diffusers import StableDiffusionKDiffusionPipeline
    import torch
    import os

    inference_steps = 25
    seed = 1234  # note: `seed` was not defined in this scratch block; any int works

    pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
    pipe = pipe.to("cuda")

    # `model` is undefined at this point; this leftover line would raise a NameError
    # k_diffusion_model = CompVisDenoiser(model)

    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
    pipe = pipe.to("cuda")

    prompt = "an astronaut riding a horse on mars"
    pipe.set_scheduler("sample_heun")
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

    image.save("./astronaut_heun_k_diffusion.png")
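A quick sanity-check sketch for get_sampler (assumes a recent diffusers release; the expectations simply mirror the branch table above):

    from toolkit.sampler import get_sampler

    sched = get_sampler("k_euler")  # "k_" prefix requests Karras sigma spacing
    assert type(sched).__name__ == "EulerDiscreteScheduler"
    assert sched.config.use_karras_sigmas is True

    sched = get_sampler("dpmsolver++")
    assert type(sched).__name__ == "DPMSolverMultistepScheduler"
    assert sched.config.algorithm_type == "dpmsolver++"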
@@ -1,5 +1,6 @@
import gc
import json
+import shutil
import typing
from typing import Union, List, Tuple, Iterator
import sys
@@ -19,13 +20,15 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
+from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler
-from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
+from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
+    StableDiffusionKDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
@@ -84,24 +87,14 @@ if typing.TYPE_CHECKING:


class StableDiffusion:
-    pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
-    vae: Union[None, 'AutoencoderKL']
-    unet: Union[None, 'UNet2DConditionModel']
-    text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
-    tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-    noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler']
-    device: str
-    dtype: str
-    torch_dtype: torch.dtype
-    device_torch: torch.device
-    model_config: ModelConfig

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='fp16',
-            custom_pipeline=None
+            custom_pipeline=None,
+            noise_scheduler=None,
    ):
        self.custom_pipeline = custom_pipeline
        self.device = device
@@ -111,6 +104,13 @@ class StableDiffusion:
        self.model_config = model_config
        self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

+        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
+        self.vae: Union[None, 'AutoencoderKL']
+        self.unet: Union[None, 'UNet2DConditionModel']
+        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
+        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
+        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] = noise_scheduler
+
        # sdxl stuff
        self.logit_scale = None
        self.ckppt_info = None
@@ -124,6 +124,8 @@ class StableDiffusion:
        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2

+        self.config_file = None
+
    def load_model(self):
        if self.is_loaded:
            return
@@ -131,23 +133,25 @@ class StableDiffusion:

        # TODO handle other schedulers
        # sch = KDPM2DiscreteScheduler
-        sch = DDPMScheduler
-        # do our own scheduler
-        prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
-        scheduler = sch(
-            num_train_timesteps=1000,
-            beta_start=0.00085,
-            beta_end=0.0120,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            prediction_type=prediction_type,
-            steps_offset=1
-        )
+        if self.noise_scheduler is None:
+            sch = DDPMScheduler
+            # do our own scheduler
+            prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
+            scheduler = sch(
+                num_train_timesteps=1000,
+                beta_start=0.00085,
+                beta_end=0.0120,
+                beta_schedule="scaled_linear",
+                clip_sample=False,
+                prediction_type=prediction_type,
+                steps_offset=0
+            )
+            self.noise_scheduler = scheduler

        # move the betas alphas and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
-        scheduler.betas = scheduler.betas.to(self.device_torch)
-        scheduler.alphas = scheduler.alphas.to(self.device_torch)
-        scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch)
+        self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+        self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)

        model_path = self.model_config.name_or_path
        if 'civitai.com' in self.model_config.name_or_path:
@@ -159,7 +163,8 @@ class StableDiffusion:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
-                pipln = CustomStableDiffusionXLPipeline
+                pipln = StableDiffusionXLPipeline
+                # pipln = StableDiffusionKDiffusionXLPipeline

            # see if path exists
            if not os.path.exists(model_path):
@@ -204,7 +209,7 @@ class StableDiffusion:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
-                pipln = CustomStableDiffusionPipeline
+                pipln = StableDiffusionPipeline

            # see if path exists
            if not os.path.exists(model_path):
@@ -237,14 +242,13 @@ class StableDiffusion:
            tokenizer = pipe.tokenizer

        # scheduler doesn't get set sometimes, so we set it here
-        pipe.scheduler = scheduler
+        pipe.scheduler = self.noise_scheduler

        if self.model_config.vae_path is not None:
            external_vae = load_vae(self.model_config.vae_path, dtype)
            pipe.vae = external_vae

        self.unet = pipe.unet
-        self.noise_scheduler = pipe.scheduler
        self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
@@ -257,7 +261,7 @@ class StableDiffusion:
        self.pipeline = pipe
        self.is_loaded = True

-    def generate_images(self, image_configs: List[GenerateImageConfig]):
+    def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
        # sample_folder = os.path.join(self.save_root, 'samples')
        if self.network is not None:
            self.network.eval()
@@ -293,16 +297,31 @@ class StableDiffusion:
        self.vae.to(self.device_torch)
        self.unet.to(self.device_torch)

+        noise_scheduler = self.noise_scheduler
+        if sampler is not None:
+            if sampler.startswith("sample_"):  # sample_dpmpp_2m
+                # using ksampler
+                noise_scheduler = get_sampler('lms')
+            else:
+                noise_scheduler = get_sampler(sampler)
+
+        if sampler.startswith("sample_") and self.is_xl:
+            # using kdiffusion
+            Pipe = StableDiffusionKDiffusionXLPipeline
+        else:
+            Pipe = StableDiffusionXLPipeline
+
+
        # TODO add clip skip
        if self.is_xl:
-            pipeline = StableDiffusionXLPipeline(
+            pipeline = Pipe(
                vae=self.vae,
                unet=self.unet,
                text_encoder=self.text_encoder[0],
                text_encoder_2=self.text_encoder[1],
                tokenizer=self.tokenizer[0],
                tokenizer_2=self.tokenizer[1],
-                scheduler=self.noise_scheduler,
+                scheduler=noise_scheduler,
                add_watermarker=False,
            ).to(self.device_torch)
            # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
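In short, callers now just pick a name. A hedged usage sketch (generate_images and the sampler names are from this commit; the configs object is a stand-in):

    # diffusers path: swaps in EulerAncestralDiscreteScheduler via get_sampler
    sd.generate_images(configs, sampler="euler_a")
    # k-diffusion path (SDXL only): LMS base scheduler, then set_scheduler("sample_dpmpp_2m")
    sd.generate_images(configs, sampler="sample_dpmpp_2m")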
@@ -313,7 +332,7 @@ class StableDiffusion:
                unet=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
-                scheduler=self.noise_scheduler,
+                scheduler=noise_scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
@@ -321,6 +340,9 @@ class StableDiffusion:
        # disable progress bar
        pipeline.set_progress_bar_config(disable=True)

+        if sampler.startswith("sample_"):
+            pipeline.set_scheduler(sampler)
+
        start_multiplier = 1.0
        if self.network is not None:
            start_multiplier = self.network.multiplier
@@ -345,8 +367,14 @@ class StableDiffusion:
                # was trained on 0.7 (I believe)

                grs = gen_config.guidance_rescale
-                if grs is None or grs < 0.00001:
-                    grs = 0.7
+                # if grs is None or grs < 0.00001:
+                #     grs = 0.7
+                grs = 0.0

+                extra = {}
+                if sampler.startswith("sample_"):
+                    extra['use_karras_sigmas'] = True


                img = pipeline(
                    prompt=gen_config.prompt,
@@ -358,6 +386,7 @@ class StableDiffusion:
                    num_inference_steps=gen_config.num_inference_steps,
                    guidance_scale=gen_config.guidance_scale,
                    guidance_rescale=grs,
+                    **extra
                ).images[0]
            else:
                img = pipeline(
@@ -414,11 +443,11 @@ class StableDiffusion:
        noise = torch.randn(
            (
                batch_size,
-                UNET_IN_CHANNELS,
+                self.unet.config['in_channels'],
                height,
                width,
            ),
-            device="cpu",
+            device=self.unet.device,
        )
        noise = apply_noise_offset(noise, noise_offset)
        return noise
@@ -784,6 +813,10 @@ class StableDiffusion:
                save_dtype=save_dtype,
                sd_version=version_string,
            )
+            if self.config_file is not None:
+                output_path_no_ext = os.path.splitext(output_file)[0]
+                output_config_path = f"{output_path_no_ext}.yaml"
+                shutil.copyfile(self.config_file, output_config_path)

    def prepare_optimizer_params(
            self,