reworked samplers. Trying to find what is wrong with diffusers sampling in SDXL

Jaret Burkett
2023-09-03 07:56:09 -06:00
parent 4ca819a05e
commit 2a40937b4f
8 changed files with 517 additions and 63 deletions

View File

@@ -1,3 +1,4 @@
import copy
import glob
from collections import OrderedDict
import os
@@ -11,6 +12,7 @@ from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.sampler import get_sampler
from toolkit.scheduler import get_lr_scheduler
from toolkit.stable_diffusion_model import StableDiffusion
@@ -89,6 +91,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None:
# get the latest checkpoint
# check to see if we have a latest save
@@ -96,7 +100,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
@@ -104,11 +108,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# to hold network if there is one
@@ -164,7 +172,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
))
# send to be generated
self.sd.generate_images(gen_img_config_list)
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
def update_training_metadata(self):
o_dict = OrderedDict({
@@ -216,10 +224,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
for file in files[:-self.save_config.max_step_saves_to_keep]:
self.print(f"Removing old save: {file}")
os.remove(file)
# see if a yaml file with the same name exists
yaml_file = os.path.splitext(file)[0] + ".yaml"
if os.path.exists(yaml_file):
os.remove(yaml_file)
return latest_file
else:
return None
def post_save_hook(self, save_path):
# override in subclass
pass
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -263,6 +279,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
# Called before the model is loaded
def hook_before_model_load(self):
@@ -279,6 +296,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
def before_dataset_load(self):
pass
def get_params(self):
# you can override this in a subclass to supply params directly
# otherwise params will be gathered through the normal means
return None
def hook_train_loop(self, batch):
# return loss
return 0.0
@@ -445,11 +467,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.prepare_grad_etc(text_encoder, unet)
params = self.network.prepare_optimizer_params(
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
params = self.get_params()
if not params:
params = self.network.prepare_optimizer_params(
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
@@ -477,8 +502,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.step_num = self.embedding.step
self.start_step = self.step_num
# set trainable params
params = self.embedding.get_trainable_params()
params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
else:
# set them to train or not
@@ -506,14 +533,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.text_encoder.requires_grad_(False)
self.sd.text_encoder.eval()
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
params = self.get_params()
if params is None:
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
### HOOK ###
params = self.hook_add_extra_train_params(params)

View File

@@ -98,7 +98,7 @@ class GenerateProcess(BaseProcess):
add_prompt_file=self.generate_config.prompt_file
))
# generate images
self.sd.generate_images(prompt_image_configs)
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
print("Done generating images")
# cleanup
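With this change the sampler name flows from the generate config into StableDiffusion.generate_images. A minimal sketch of the call, assuming a loaded StableDiffusion instance `sd` and that GenerateImageConfig accepts a prompt kwarg (its fields are not shown in this diff):

from toolkit.config_modules import GenerateImageConfig

configs = [GenerateImageConfig(prompt="a photo of a cat")]  # prompt kwarg assumed
sd.generate_images(configs, sampler="euler_a")          # plain diffusers scheduler via get_sampler()
sd.generate_images(configs, sampler="sample_dpmpp_2m")  # "sample_" prefix routes SDXL through k-diffusion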

View File

@@ -15,4 +15,5 @@ accelerate
toml
albumentations
pydantic
omegaconf
omegaconf
k-diffusion

View File

@@ -22,6 +22,7 @@ class LogingConfig:
class SampleConfig:
def __init__(self, **kwargs):
self.sampler: str = kwargs.get('sampler', 'ddpm')
self.sample_every: int = kwargs.get('sample_every', 100)
self.width: int = kwargs.get('width', 512)
self.height: int = kwargs.get('height', 512)
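The new sampler key is read straight from kwargs and defaults to 'ddpm'. A minimal sketch of building the config (the toolkit.config_modules import path is an assumption; the other fields are the ones shown above):

from toolkit.config_modules import SampleConfig  # import path assumed

sample_config = SampleConfig(sampler="sample_dpmpp_2m", sample_every=250, width=1024, height=1024)
print(sample_config.sampler)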

View File

@@ -35,6 +35,10 @@ class LosslessLatentDecoder(nn.Module):
return kernel
def forward(self, x):
dtype = x.dtype
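# the kernel is built once up front; cast the cached copy lazily so fp16/bf16 inputs keep working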
if self.kernel.dtype != dtype:
self.kernel = self.kernel.to(dtype=dtype)
# Deconvolve input tensor with the kernel
return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
@@ -70,6 +74,9 @@ class LosslessLatentEncoder(nn.Module):
return kernel
def forward(self, x):
dtype = x.dtype
if self.kernel.dtype != dtype:
self.kernel = self.kernel.to(dtype=dtype)
# Convolve input tensor with the kernel
return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)

View File

@@ -1,14 +1,289 @@
import importlib
import inspect
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
def __init__(
self,
vae: 'AutoencoderKL',
text_encoder: 'CLIPTextModel',
text_encoder_2: 'CLIPTextModelWithProjection',
tokenizer: 'CLIPTokenizer',
tokenizer_2: 'CLIPTokenizer',
unet: 'UNet2DConditionModel',
scheduler: 'KarrasDiffusionSchedulers',
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
)
self.sampler = None
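# wrap the unet so k-diffusion samplers can call it: the LMS config supplies alphas_cumprod,
# and v-prediction checkpoints need the v-objective denoiser rather than the epsilon one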
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
use_karras_sigmas: bool = False,
):
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# 5. Prepare sigmas
if use_karras_sigmas:
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
sigmas = sigmas.to(device)
else:
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(prompt_embeds.dtype)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents = latents * sigmas[0]
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
# 7. Define model function
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# noise_pred = self.unet(
# latent_model_input,
# t,
# encoder_hidden_states=prompt_embeds,
# cross_attention_kwargs=cross_attention_kwargs,
# added_cond_kwargs=added_cond_kwargs,
# return_dict=False,
# )[0]
noise_pred = self.k_diffusion_model(
latent_model_input,
t,
cond=prompt_embeds,  # ModelWrapper routes cond= to encoder_hidden_states and returns a plain tensor
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred
# 8. Run k-diffusion solver
sampler_kwargs = {}
# no fixed seed is needed; None lets the Brownian tree noise sampler pick its own
noise_sampler_seed = None
if "noise_sampler" in inspect.signature(self.sampler).parameters:
min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
sampler_kwargs["noise_sampler"] = noise_sampler
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# the SDXL pipeline has no safety checker, so postprocess directly
image = self.image_processor.postprocess(image, output_type=output_type)
else:
image = latents
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
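# Example usage sketch: driving this pipeline directly. It assumes from_pretrained can
# assemble this subclass from the base SDXL repo; the toolkit normally builds it inside
# StableDiffusion.generate_images instead.
#   pipe = StableDiffusionKDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.set_scheduler("sample_dpmpp_2m")
#   image = pipe("an astronaut riding a horse on mars", num_inference_steps=25,
#                use_karras_sigmas=True).images[0]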
class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
def predict_noise(
self,

toolkit/sampler.py (new file, 107 lines)
View File

@@ -0,0 +1,107 @@
import copy
from diffusers import (
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from k_diffusion.external import CompVisDenoiser
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
sdxl_sampler_config = {
"_class_name": "EulerDiscreteScheduler",
"_diffusers_version": "0.19.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
"trained_betas": None,
"use_karras_sigmas": False
}
def get_sampler(
sampler: str,
):
sched_init_args = {}
if sampler.startswith("k_"):
sched_init_args["use_karras_sigmas"] = True
if sampler == "ddim":
scheduler_cls = DDIMScheduler
elif sampler == "ddpm": # ddpm is not supported ?
scheduler_cls = DDPMScheduler
elif sampler == "pndm":
scheduler_cls = PNDMScheduler
elif sampler == "lms" or sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif sampler == "euler" or sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif sampler == "euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = sampler.replace("k_", "")
elif sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif sampler == "dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif sampler == "dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
config = copy.deepcopy(sdxl_sampler_config)
config.update(sched_init_args)
scheduler = scheduler_cls.from_config(config)
return scheduler
# testing
if __name__ == "__main__":
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionKDiffusionPipeline
import torch
import os
inference_steps = 25
seed = 0  # arbitrary fixed seed so the test run is reproducible
pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")
# the pipeline wraps its unet in a CompVisDenoiser internally
k_diffusion_model = pipe.k_diffusion_model
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
pipe.set_scheduler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=inference_steps).images[0]
image.save("./astronaut_heun_k_diffusion.png")
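get_sampler maps these string names onto diffusers scheduler classes layered over the SDXL Euler defaults above. A minimal usage sketch:

from toolkit.sampler import get_sampler

scheduler = get_sampler("k_euler")
# EulerDiscreteScheduler built from sdxl_sampler_config, with use_karras_sigmas enabled by the "k_" prefix
print(type(scheduler).__name__, scheduler.config.use_karras_sigmas)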

View File

@@ -1,5 +1,6 @@
import gc
import json
import shutil
import typing
from typing import Union, List, Tuple, Iterator
import sys
@@ -19,13 +20,15 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
@@ -84,24 +87,14 @@ if typing.TYPE_CHECKING:
class StableDiffusion:
pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
vae: Union[None, 'AutoencoderKL']
unet: Union[None, 'UNet2DConditionModel']
text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler']
device: str
dtype: str
torch_dtype: torch.dtype
device_torch: torch.device
model_config: ModelConfig
def __init__(
self,
device,
model_config: ModelConfig,
dtype='fp16',
custom_pipeline=None
custom_pipeline=None,
noise_scheduler=None,
):
self.custom_pipeline = custom_pipeline
self.device = device
@@ -111,6 +104,13 @@ class StableDiffusion:
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] = noise_scheduler
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
@@ -124,6 +124,8 @@ class StableDiffusion:
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
self.config_file = None
def load_model(self):
if self.is_loaded:
return
@@ -131,23 +133,25 @@ class StableDiffusion:
# TODO handle other schedulers
# sch = KDPM2DiscreteScheduler
sch = DDPMScheduler
# do our own scheduler
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
scheduler = sch(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
steps_offset=1
)
if self.noise_scheduler is None:
sch = DDPMScheduler
# do our own scheduler
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
scheduler = sch(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
steps_offset=0
)
self.noise_scheduler = scheduler
# move the betas, alphas and alphas_cumprod to the device. Sometimes they get stuck on the CPU, not sure why
scheduler.betas = scheduler.betas.to(self.device_torch)
scheduler.alphas = scheduler.alphas.to(self.device_torch)
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch)
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
model_path = self.model_config.name_or_path
if 'civitai.com' in self.model_config.name_or_path:
@@ -159,7 +163,8 @@ class StableDiffusion:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionXLPipeline
pipln = StableDiffusionXLPipeline
# pipln = StableDiffusionKDiffusionXLPipeline
# see if path exists
if not os.path.exists(model_path):
@@ -204,7 +209,7 @@ class StableDiffusion:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionPipeline
pipln = StableDiffusionPipeline
# see if path exists
if not os.path.exists(model_path):
@@ -237,14 +242,13 @@ class StableDiffusion:
tokenizer = pipe.tokenizer
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
pipe.scheduler = self.noise_scheduler
if self.model_config.vae_path is not None:
external_vae = load_vae(self.model_config.vae_path, dtype)
pipe.vae = external_vae
self.unet = pipe.unet
self.noise_scheduler = pipe.scheduler
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -257,7 +261,7 @@ class StableDiffusion:
self.pipeline = pipe
self.is_loaded = True
def generate_images(self, image_configs: List[GenerateImageConfig]):
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
# sample_folder = os.path.join(self.save_root, 'samples')
if self.network is not None:
self.network.eval()
@@ -293,16 +297,31 @@ class StableDiffusion:
self.vae.to(self.device_torch)
self.unet.to(self.device_torch)
noise_scheduler = self.noise_scheduler
if sampler is not None:
if sampler.startswith("sample_"): # sample_dpmpp_2m
# using ksampler
noise_scheduler = get_sampler('lms')
else:
noise_scheduler = get_sampler(sampler)
if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline
else:
Pipe = StableDiffusionXLPipeline
# TODO add clip skip
if self.is_xl:
pipeline = StableDiffusionXLPipeline(
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
scheduler=self.noise_scheduler,
scheduler=noise_scheduler,
add_watermarker=False,
).to(self.device_torch)
# force-disable the watermarker (it ruins images with obvious green and red dots)
@@ -313,7 +332,7 @@ class StableDiffusion:
unet=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=self.noise_scheduler,
scheduler=noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
@@ -321,6 +340,9 @@ class StableDiffusion:
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
if sampler is not None and sampler.startswith("sample_"):
pipeline.set_scheduler(sampler)
start_multiplier = 1.0
if self.network is not None:
start_multiplier = self.network.multiplier
@@ -345,8 +367,14 @@ class StableDiffusion:
# was trained on 0.7 (I believe)
grs = gen_config.guidance_rescale
if grs is None or grs < 0.00001:
grs = 0.7
# if grs is None or grs < 0.00001:
# grs = 0.7
grs = 0.0
extra = {}
if sampler is not None and sampler.startswith("sample_"):
extra['use_karras_sigmas'] = True
img = pipeline(
prompt=gen_config.prompt,
@@ -358,6 +386,7 @@ class StableDiffusion:
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=grs,
**extra
).images[0]
else:
img = pipeline(
@@ -414,11 +443,11 @@ class StableDiffusion:
noise = torch.randn(
(
batch_size,
UNET_IN_CHANNELS,
self.unet.config['in_channels'],
height,
width,
),
device="cpu",
device=self.unet.device,
)
noise = apply_noise_offset(noise, noise_offset)
return noise
@@ -784,6 +813,10 @@ class StableDiffusion:
save_dtype=save_dtype,
sd_version=version_string,
)
if self.config_file is not None:
output_path_no_ext = os.path.splitext(output_file)[0]
output_config_path = f"{output_path_no_ext}.yaml"
shutil.copyfile(self.config_file, output_config_path)
def prepare_optimizer_params(
self,