diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 530f18e..0f8c23a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1,3 +1,4 @@
+import copy
 import glob
 from collections import OrderedDict
 import os
@@ -11,6 +12,7 @@
 from toolkit.embedding import Embedding
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
+from toolkit.sampler import get_sampler
 from toolkit.scheduler import get_lr_scheduler
 from toolkit.stable_diffusion_model import StableDiffusion
@@ -89,6 +91,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if embedding_raw is not None:
             self.embed_config = EmbeddingConfig(**embedding_raw)
 
+        model_config_to_load = copy.deepcopy(self.model_config)
+
         if self.embed_config is None and self.network_config is None:
             # get the latest checkpoint
             # check to see if we have a latest save
@@ -96,7 +100,7 @@
             if latest_save_path is not None:
                 print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                self.model_config.name_or_path = latest_save_path
+                model_config_to_load.name_or_path = latest_save_path
                 meta = load_metadata_from_safetensors(latest_save_path)
                 # if 'training_info' in Orderdict keys
                 if 'training_info' in meta and 'step' in meta['training_info']:
@@ -104,11 +108,15 @@
                     self.start_step = self.step_num
                     print(f"Found step {self.step_num} in metadata, starting from there")
 
+        # get the noise scheduler
+        sampler = get_sampler(self.train_config.noise_scheduler)
+
         self.sd = StableDiffusion(
             device=self.device,
-            model_config=self.model_config,
+            model_config=model_config_to_load,
             dtype=self.train_config.dtype,
             custom_pipeline=self.custom_pipeline,
+            noise_scheduler=sampler,
         )
 
         # to hold network if there is one
@@ -164,7 +172,7 @@
             ))
         # send to be generated
-        self.sd.generate_images(gen_img_config_list)
+        self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
 
     def update_training_metadata(self):
         o_dict = OrderedDict({
@@ -216,10 +224,18 @@
             for file in files[:-self.save_config.max_step_saves_to_keep]:
                 self.print(f"Removing old save: {file}")
                 os.remove(file)
+                # see if a yaml file with same name exists
+                yaml_file = os.path.splitext(file)[0] + ".yaml"
+                if os.path.exists(yaml_file):
+                    os.remove(yaml_file)
             return latest_file
         else:
             return None
 
+    def post_save_hook(self, save_path):
+        # override in subclass
+        pass
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -263,6 +279,7 @@
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
+        self.post_save_hook(file_path)
 
     # Called before the model is loaded
     def hook_before_model_load(self):
@@ -279,6 +296,11 @@
     def before_dataset_load(self):
         pass
 
+    def get_params(self):
+        # you can extend this in subclass to get params
+        # otherwise params will be gathered through normal means
+        return None
+
     def hook_train_loop(self, batch):
         # return loss
         return 0.0
@@ -445,11 +467,14 @@
 
             self.network.prepare_grad_etc(text_encoder, unet)
 
-            params = self.network.prepare_optimizer_params(
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if not params:
+                params = self.network.prepare_optimizer_params(
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )
 
             if self.train_config.gradient_checkpointing:
                 self.network.enable_gradient_checkpointing()
@@ -477,8 +502,10 @@
                 self.step_num = self.embedding.step
                 self.start_step = self.step_num
 
-            # set trainable params
-            params = self.embedding.get_trainable_params()
+            params = self.get_params()
+            if not params:
+                # set trainable params
+                params = self.embedding.get_trainable_params()
 
         else:
             # set them to train or not
@@ -506,14 +533,17 @@
                 self.sd.text_encoder.requires_grad_(False)
                 self.sd.text_encoder.eval()
 
-            # will only return savable weights and ones with grad
-            params = self.sd.prepare_optimizer_params(
-                unet=self.train_config.train_unet,
-                text_encoder=self.train_config.train_text_encoder,
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if params is None:
+                # will only return savable weights and ones with grad
+                params = self.sd.prepare_optimizer_params(
+                    unet=self.train_config.train_unet,
+                    text_encoder=self.train_config.train_text_encoder,
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )
 
         ### HOOK ###
         params = self.hook_add_extra_train_params(params)
diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py
index a005ae0..13acca2 100644
--- a/jobs/process/GenerateProcess.py
+++ b/jobs/process/GenerateProcess.py
@@ -98,7 +98,7 @@ class GenerateProcess(BaseProcess):
                 add_prompt_file=self.generate_config.prompt_file
             ))
         # generate images
-        self.sd.generate_images(prompt_image_configs)
+        self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
         print("Done generating images")
 
         # cleanup
diff --git a/requirements.txt b/requirements.txt
index a11b621..2895366 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,5 @@ accelerate
 toml
 albumentations
 pydantic
-omegaconf
\ No newline at end of file
+omegaconf
+k-diffusion
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ac612f8..b1e6cd8 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -22,6 +22,7 @@ class LogingConfig:
 
 class SampleConfig:
     def __init__(self, **kwargs):
+        self.sampler: str = kwargs.get('sampler', 'ddpm')
        self.sample_every: int = kwargs.get('sample_every', 100)
        self.width: int = kwargs.get('width', 512)
        self.height: int = kwargs.get('height', 512)
diff --git a/toolkit/llvae.py b/toolkit/llvae.py
index aedbcb6..f8ed8f5 100644
--- a/toolkit/llvae.py
+++ b/toolkit/llvae.py
@@ -35,6 +35,10 @@ class LosslessLatentDecoder(nn.Module):
         return kernel
 
     def forward(self, x):
+        dtype = x.dtype
+        if self.kernel.dtype != dtype:
+            self.kernel = self.kernel.to(dtype=dtype)
+
         # Deconvolve input tensor with the kernel
         return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
 
@@ -70,6 +74,9 @@
         return kernel
 
     def forward(self, x):
+        dtype = x.dtype
+        if self.kernel.dtype != dtype:
+            self.kernel = self.kernel.to(dtype=dtype)
         # Convolve input tensor with the kernel
         return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
index f772fa1..83ca0a8 100644
--- a/toolkit/pipelines.py
+++ b/toolkit/pipelines.py
@@ -1,14 +1,289 @@
+import importlib
+import inspect
 from typing import Union, List, Optional, Dict, Any, Tuple, Callable
 
 import torch
-from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
+from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
+from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
+
+
+class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
+
+    def __init__(
+            self,
+            vae: 'AutoencoderKL',
+            text_encoder: 'CLIPTextModel',
+            text_encoder_2: 'CLIPTextModelWithProjection',
+            tokenizer: 'CLIPTokenizer',
+            tokenizer_2: 'CLIPTokenizer',
+            unet: 'UNet2DConditionModel',
+            scheduler: 'KarrasDiffusionSchedulers',
+            force_zeros_for_empty_prompt: bool = True,
+            add_watermarker: Optional[bool] = None,
+    ):
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            scheduler=scheduler,
+        )
+        self.sampler = None
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+        model = ModelWrapper(unet, scheduler.alphas_cumprod)
+        if scheduler.config.prediction_type == "v_prediction":
+            self.k_diffusion_model = CompVisVDenoiser(model)
+        else:
+            self.k_diffusion_model = CompVisDenoiser(model)
+
+    def set_scheduler(self, scheduler_type: str):
+        library = importlib.import_module("k_diffusion")
+        sampling = getattr(library, "sampling")
+        self.sampler = getattr(sampling, scheduler_type)
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            prompt: Union[str, List[str]] = None,
+            prompt_2: Optional[Union[str, List[str]]] = None,
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            num_inference_steps: int = 50,
+            denoising_end: Optional[float] = None,
+            guidance_scale: float = 5.0,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+            callback_steps: int = 1,
+            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+            guidance_rescale: float = 0.0,
+            original_size: Optional[Tuple[int, int]] = None,
+            crops_coords_top_left: Tuple[int, int] = (0, 0),
+            target_size: Optional[Tuple[int, int]] = None,
+            use_karras_sigmas: bool = False,
+    ):
+
+        # 0. Default height and width to unet
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = pooled_prompt_embeds
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+        )
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 7.1 Apply denoising_end
+        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 5. Prepare sigmas
+        if use_karras_sigmas:
+            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+            sigmas = sigmas.to(device)
+        else:
+            sigmas = self.scheduler.sigmas
+        sigmas = sigmas.to(prompt_embeds.dtype)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        latents = latents * sigmas[0]
+        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+        # 7. Define model function
+        def model_fn(x, t):
+            latent_model_input = torch.cat([x] * 2)
+            t = torch.cat([t] * 2)
+
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            # noise_pred = self.unet(
+            #     latent_model_input,
+            #     t,
+            #     encoder_hidden_states=prompt_embeds,
+            #     cross_attention_kwargs=cross_attention_kwargs,
+            #     added_cond_kwargs=added_cond_kwargs,
+            #     return_dict=False,
+            # )[0]
+
+            noise_pred = self.k_diffusion_model(
+                latent_model_input,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                cross_attention_kwargs=cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )[0]
+
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            return noise_pred
+
+        # 8. Run k-diffusion solver
+        sampler_kwargs = {}
+        # should work without it
+        noise_sampler_seed = None
+
+        if "noise_sampler" in inspect.signature(self.sampler).parameters:
+            min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
+            noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
+            sampler_kwargs["noise_sampler"] = noise_sampler
+
+        latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
 
 
 class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
-    # def __init__(self, *args, **kwargs):
-    #     super().__init__(*args, **kwargs)
 
     def predict_noise(
         self,
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
new file mode 100644
index 0000000..098f452
--- /dev/null
+++ b/toolkit/sampler.py
@@ -0,0 +1,107 @@
+import copy
+
+from diffusers import (
+    DDPMScheduler,
+    EulerAncestralDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler,
+)
+from k_diffusion.external import CompVisDenoiser
+
+# scheduler:
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDLER_SCHEDULE = "scaled_linear"
+
+sdxl_sampler_config = {
+    "_class_name": "EulerDiscreteScheduler",
+    "_diffusers_version": "0.19.0.dev0",
+    "beta_end": 0.012,
+    "beta_schedule": "scaled_linear",
+    "beta_start": 0.00085,
+    "clip_sample": False,
+    "interpolation_type": "linear",
+    "num_train_timesteps": 1000,
+    "prediction_type": "epsilon",
+    "sample_max_value": 1.0,
+    "set_alpha_to_one": False,
+    "skip_prk_steps": True,
+    "steps_offset": 1,
+    "timestep_spacing": "leading",
+    "trained_betas": None,
+    "use_karras_sigmas": False
+}
+
+
+def get_sampler(
+        sampler: str,
+):
+    sched_init_args = {}
+
+    if sampler.startswith("k_"):
+        sched_init_args["use_karras_sigmas"] = True
+
+    if sampler == "ddim":
+        scheduler_cls = DDIMScheduler
+    elif sampler == "ddpm":  # ddpm is not supported ?
+        scheduler_cls = DDPMScheduler
+    elif sampler == "pndm":
+        scheduler_cls = PNDMScheduler
+    elif sampler == "lms" or sampler == "k_lms":
+        scheduler_cls = LMSDiscreteScheduler
+    elif sampler == "euler" or sampler == "k_euler":
+        scheduler_cls = EulerDiscreteScheduler
+    elif sampler == "euler_a":
+        scheduler_cls = EulerAncestralDiscreteScheduler
+    elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
+        scheduler_cls = DPMSolverMultistepScheduler
+        sched_init_args["algorithm_type"] = sampler.replace("k_", "")
+    elif sampler == "dpmsingle":
+        scheduler_cls = DPMSolverSinglestepScheduler
+    elif sampler == "heun":
+        scheduler_cls = HeunDiscreteScheduler
+    elif sampler == "dpm_2":
+        scheduler_cls = KDPM2DiscreteScheduler
+    elif sampler == "dpm_2_a":
+        scheduler_cls = KDPM2AncestralDiscreteScheduler
+
+    config = copy.deepcopy(sdxl_sampler_config)
+    config.update(sched_init_args)
+
+    scheduler = scheduler_cls.from_config(config)
+
+    return scheduler
+
+
+# testing
+if __name__ == "__main__":
+    from diffusers import DiffusionPipeline
+
+    from diffusers import StableDiffusionKDiffusionPipeline
+    import torch
+    import os
+
+    inference_steps = 25
+    seed = 42
+
+    pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+    pipe = pipe.to("cuda")
+
+    # `model` is not defined in this test block, so building a denoiser here would raise a NameError
+    # k_diffusion_model = CompVisDenoiser(model)
+
+    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+    pipe = pipe.to("cuda")
+
+    prompt = "an astronaut riding a horse on mars"
+    pipe.set_scheduler("sample_heun")
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
+
+    image.save("./astronaut_heun_k_diffusion.png")
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index cbfc48c..878fca1 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1,5 +1,6 @@
 import gc
 import json
+import shutil
 import typing
 from typing import Union, List, Tuple, Iterator
 import sys
@@ -19,13 +20,15 @@
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
+from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from library import model_util
 from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
 from diffusers.schedulers import DDPMScheduler
-from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
+from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
+    StableDiffusionKDiffusionXLPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 import diffusers
@@ -84,24 +87,14 @@ if typing.TYPE_CHECKING:
 
 
 class StableDiffusion:
-    pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
-    vae: Union[None, 'AutoencoderKL']
-    unet: Union[None, 'UNet2DConditionModel']
-    text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
-    tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-    noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler']
-    device: str
-    dtype: str
-    torch_dtype: torch.dtype
-    device_torch: torch.device
-    model_config: ModelConfig
 
     def __init__(
             self,
             device,
             model_config: ModelConfig,
             dtype='fp16',
-            custom_pipeline=None
+            custom_pipeline=None,
+            noise_scheduler=None,
     ):
         self.custom_pipeline = custom_pipeline
         self.device = device
@@ -111,6 +104,13 @@
         self.model_config = model_config
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
 
+        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
+        self.vae: Union[None, 'AutoencoderKL']
+        self.unet: Union[None, 'UNet2DConditionModel']
+        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
+        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
+        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] = noise_scheduler
+
         # sdxl stuff
         self.logit_scale = None
         self.ckppt_info = None
@@ -124,6 +124,8 @@
         self.use_text_encoder_1 = model_config.use_text_encoder_1
         self.use_text_encoder_2 = model_config.use_text_encoder_2
 
+        self.config_file = None
+
     def load_model(self):
         if self.is_loaded:
             return
@@ -131,23 +133,25 @@
         # TODO handle other schedulers
         # sch = KDPM2DiscreteScheduler
-        sch = DDPMScheduler
-        # do our own scheduler
-        prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
-        scheduler = sch(
-            num_train_timesteps=1000,
-            beta_start=0.00085,
-            beta_end=0.0120,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            prediction_type=prediction_type,
-            steps_offset=1
-        )
+        if self.noise_scheduler is None:
+            sch = DDPMScheduler
+            # do our own scheduler
+            prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
+            scheduler = sch(
+                num_train_timesteps=1000,
+                beta_start=0.00085,
+                beta_end=0.0120,
+                beta_schedule="scaled_linear",
+                clip_sample=False,
+                prediction_type=prediction_type,
+                steps_offset=0
+            )
+            self.noise_scheduler = scheduler
         # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
-        scheduler.betas = scheduler.betas.to(self.device_torch)
-        scheduler.alphas = scheduler.alphas.to(self.device_torch)
-        scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch)
+        self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+        self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
 
         model_path = self.model_config.name_or_path
         if 'civitai.com' in self.model_config.name_or_path:
@@ -159,7 +163,8 @@
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
             else:
-                pipln = CustomStableDiffusionXLPipeline
+                pipln = StableDiffusionXLPipeline
+                # pipln = StableDiffusionKDiffusionXLPipeline
 
             # see if path exists
             if not os.path.exists(model_path):
@@ -204,7 +209,7 @@
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
             else:
-                pipln = CustomStableDiffusionPipeline
+                pipln = StableDiffusionPipeline
 
             # see if path exists
             if not os.path.exists(model_path):
@@ -237,14 +242,13 @@
             tokenizer = pipe.tokenizer
 
         # scheduler doesn't get set sometimes, so we set it here
-        pipe.scheduler = scheduler
+        pipe.scheduler = self.noise_scheduler
 
         if self.model_config.vae_path is not None:
             external_vae = load_vae(self.model_config.vae_path, dtype)
             pipe.vae = external_vae
 
         self.unet = pipe.unet
-        self.noise_scheduler = pipe.scheduler
         self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
         self.vae.eval()
         self.vae.requires_grad_(False)
@@ -257,7 +261,7 @@
         self.pipeline = pipe
         self.is_loaded = True
 
-    def generate_images(self, image_configs: List[GenerateImageConfig]):
+    def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
         # sample_folder = os.path.join(self.save_root, 'samples')
         if self.network is not None:
             self.network.eval()
@@ -293,16 +297,31 @@
         self.vae.to(self.device_torch)
         self.unet.to(self.device_torch)
 
+        noise_scheduler = self.noise_scheduler
+        if sampler is not None:
+            if sampler.startswith("sample_"):  # sample_dpmpp_2m
+                # using ksampler
+                noise_scheduler = get_sampler('lms')
+            else:
+                noise_scheduler = get_sampler(sampler)
+
+        if sampler is not None and sampler.startswith("sample_") and self.is_xl:
+            # using kdiffusion
+            Pipe = StableDiffusionKDiffusionXLPipeline
+        else:
+            Pipe = StableDiffusionXLPipeline
+
+        # TODO add clip skip
         if self.is_xl:
-            pipeline = StableDiffusionXLPipeline(
+            pipeline = Pipe(
                 vae=self.vae,
                 unet=self.unet,
                 text_encoder=self.text_encoder[0],
                 text_encoder_2=self.text_encoder[1],
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
-                scheduler=self.noise_scheduler,
+                scheduler=noise_scheduler,
                 add_watermarker=False,
             ).to(self.device_torch)
             # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
@@ -313,7 +332,7 @@
                 unet=self.unet,
                 text_encoder=self.text_encoder,
                 tokenizer=self.tokenizer,
-                scheduler=self.noise_scheduler,
+                scheduler=noise_scheduler,
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
@@ -321,6 +340,9 @@
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
 
+        if sampler is not None and sampler.startswith("sample_"):
+            pipeline.set_scheduler(sampler)
+
         start_multiplier = 1.0
         if self.network is not None:
             start_multiplier = self.network.multiplier
@@ -345,8 +367,14 @@
                     # was trained on 0.7 (I believe)
                     grs = gen_config.guidance_rescale
-                    if grs is None or grs < 0.00001:
-                        grs = 0.7
+                    # if grs is None or grs < 0.00001:
+                    #     grs = 0.7
+                    grs = 0.0
+
+                    extra = {}
+                    if sampler is not None and sampler.startswith("sample_"):
+                        extra['use_karras_sigmas'] = True
+
                     img = pipeline(
                         prompt=gen_config.prompt,
@@ -358,6 +386,7 @@
                         num_inference_steps=gen_config.num_inference_steps,
                         guidance_scale=gen_config.guidance_scale,
                         guidance_rescale=grs,
+                        **extra
                     ).images[0]
                 else:
                     img = pipeline(
@@ -414,11 +443,11 @@
         noise = torch.randn(
             (
                 batch_size,
-                UNET_IN_CHANNELS,
+                self.unet.config['in_channels'],
                 height,
                 width,
            ),
-            device="cpu",
+            device=self.unet.device,
         )
         noise = apply_noise_offset(noise, noise_offset)
         return noise
@@ -784,6 +813,10 @@
             save_dtype=save_dtype,
             sd_version=version_string,
         )
+        if self.config_file is not None:
+            output_path_no_ext = os.path.splitext(output_file)[0]
+            output_config_path = f"{output_path_no_ext}.yaml"
+            shutil.copyfile(self.config_file, output_config_path)
 
     def prepare_optimizer_params(
         self,
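
# Not part of the patch: a minimal usage sketch of the get_sampler() helper introduced in
# toolkit/sampler.py, assuming the module paths above. The model id, prompt, and sampler name
# are illustrative; any name handled inside get_sampler ("ddim", "euler_a", "dpmsolver++", ...)
# should plug into a diffusers pipeline the same way generate_images() does in this diff.
#
#     from diffusers import StableDiffusionPipeline
#     from toolkit.sampler import get_sampler
#
#     pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#     # swap the pipeline's scheduler for one built by the new helper
#     pipe.scheduler = get_sampler("dpmsolver++")
#     image = pipe("an astronaut riding a horse on mars", num_inference_steps=25).images[0]
#     image.save("astronaut_dpmpp.png")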