Added sd1.5 and 2.1 to the diffusers pipeline flow

This commit is contained in:
Jaret Burkett
2023-07-27 12:34:48 -06:00
parent 596e57a6a6
commit b2e2e4bf47
3 changed files with 349 additions and 114 deletions

View File

@@ -3,17 +3,12 @@ import time
from collections import OrderedDict
import os
import diffusers
from safetensors import safe_open
from library import sdxl_train_util, sdxl_model_util
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT
import sys
-from toolkit.pipelines import CustomStableDiffusionXLPipeline
+from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
@@ -55,8 +50,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
-self.first_sample_config = SampleConfig(
-**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
+first_sample_config = self.get_conf('first_sample', None)
+if first_sample_config is not None:
+self.has_first_sample_requested = True
+self.first_sample_config = SampleConfig(**first_sample_config)
+else:
+self.has_first_sample_requested = False
+self.first_sample_config = self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.lr_scheduler = None
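For context, a minimal sketch of how the new flag is presumably consumed in sample(); the real method is not shown in this diff, and the is_first parameter is an assumption based on the calls near the end of this file:

# hypothetical sketch; the real sample() lives in BaseSDTrainProcess
def sample(self, step=None, is_first=False):
    # use the dedicated first-sample config only when the first sample was requested
    sample_config = self.first_sample_config if is_first else self.sample_config
    ...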
@@ -101,19 +101,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.text_encoder.to(self.device_torch)
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
-if self.sd.is_xl:
-pipeline = self.sd.pipeline
-else:
-pipeline = StableDiffusionPipeline(
-vae=self.sd.vae,
-unet=self.sd.unet,
-text_encoder=self.sd.text_encoder,
-tokenizer=self.sd.tokenizer,
-scheduler=self.sd.noise_scheduler,
-safety_checker=None,
-feature_extractor=None,
-requires_safety_checker=False,
-)
+pipeline = self.sd.pipeline
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -172,24 +160,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
-if self.sd.is_xl:
-img = pipeline(
-prompt,
-height=height,
-width=width,
-num_inference_steps=sample_config.sample_steps,
-guidance_scale=sample_config.guidance_scale,
-negative_prompt=neg,
-).images[0]
-else:
-img = pipeline(
-prompt,
-height=height,
-width=width,
-num_inference_steps=sample_config.sample_steps,
-guidance_scale=sample_config.guidance_scale,
-negative_prompt=neg,
-).images[0]
+img = pipeline(
+prompt=prompt,
+prompt_2=prompt,
+negative_prompt=neg,
+negative_prompt_2=neg,
+height=height,
+width=width,
+num_inference_steps=sample_config.sample_steps,
+guidance_scale=sample_config.guidance_scale,
+).images[0]
step_num = ''
if step is not None:
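This single call site now works for both model families because CustomStableDiffusionPipeline (added in toolkit/pipelines.py below) accepts the SDXL-only keywords and never uses them. A minimal, hypothetical sketch of that signature-compatibility pattern:

# hypothetical sketch: mirror the richer SDXL signature and ignore
# the arguments SD 1.5/2.1 has no use for
def sd_call(prompt=None, prompt_2=None, negative_prompt=None,
            negative_prompt_2=None, **kwargs):
    # prompt_2 / negative_prompt_2 are accepted only so SDXL call sites work unchanged
    return {"prompt": prompt, "negative_prompt": negative_prompt, **kwargs}

assert sd_call("a cat", prompt_2="a cat") == {"prompt": "a cat", "negative_prompt": None}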
@@ -202,9 +182,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path = os.path.join(sample_folder, filename)
img.save(output_path)
# clear pipeline and cache to reduce vram usage
-if not self.sd.is_xl:
-del pipeline
torch.cuda.empty_cache()
# restore training state
@@ -230,9 +207,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
})
if self.model_config.is_v2:
dict['ss_v2'] = True
+dict['ss_base_model_version'] = 'sd_2.1'
-if self.model_config.is_xl:
+elif self.model_config.is_xl:
dict['ss_base_model_version'] = 'sdxl_1.0'
+else:
+dict['ss_base_model_version'] = 'sd_1.5'
dict['ss_output_name'] = self.job.name
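For reference, the version tag written above resolves as in this small standalone mirror of the branch (the function name is illustrative):

def base_model_version(is_v2: bool, is_xl: bool) -> str:
    # mirrors the metadata branch above: v2 is checked first, sd 1.5 is the default
    if is_v2:
        return 'sd_2.1'
    elif is_xl:
        return 'sdxl_1.0'
    return 'sd_1.5'

assert base_model_version(False, False) == 'sd_1.5'
assert base_model_version(True, False) == 'sd_2.1'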
@@ -313,7 +293,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
):
if height is None and pixel_height is None:
raise ValueError("height or pixel_height must be specified")
raise ValueError("height or pixel_height must be specified")
if width is None and pixel_width is None:
raise ValueError("width or pixel_width must be specified")
if height is None:
@@ -371,7 +350,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
):
pass
def predict_noise(
self,
latents: torch.FloatTensor,
@@ -386,17 +364,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
# todo LECOs code looks like it is omitting noise_pred
-# noise_pred = train_util.predict_noise_xl(
-# self.sd.unet,
-# self.sd.noise_scheduler,
-# timestep,
-# latents,
-# text_embeddings.text_embeds,
-# text_embeddings.pooled_embeds,
-# add_time_ids,
-# guidance_scale=guidance_scale,
-# guidance_rescale=guidance_rescale
-# )
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
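The two lines above are the standard classifier-free-guidance batching trick: run the unconditional and conditional passes as one doubled batch, then split and blend the predictions. A self-contained sketch of the blend step:

import torch

def cfg_blend(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # the unet ran on torch.cat([latents] * 2), so the output holds the
    # unconditional and conditional predictions stacked along the batch dim
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)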
@@ -499,64 +467,66 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
+# do our own scheduler
+scheduler = KDPM2DiscreteScheduler(
+num_train_timesteps=1000,
+beta_start=0.00085,
+beta_end=0.0120,
+beta_schedule="scaled_linear",
+)
if self.model_config.is_xl:
-# do our own scheduler
-scheduler = KDPM2DiscreteScheduler(
-num_train_timesteps=1000,
-beta_start=0.00085,
-beta_end=0.0120,
-beta_schedule="scaled_linear",
-)
pipe = CustomStableDiffusionXLPipeline.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch,
).to(self.device_torch)
pipe.scheduler = scheduler
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
unet = pipe.unet
noise_scheduler = pipe.scheduler
vae = pipe.vae.to('cpu', dtype=dtype)
vae.eval()
vae.set_use_memory_efficient_attention_xformers(True)
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
tokenizer = tokenizer
flush()
else:
-tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
+pipe = CustomStableDiffusionPipeline.from_single_file(
self.model_config.name_or_path,
-scheduler_name=self.train_config.noise_scheduler,
-v2=self.model_config.is_v2,
-v_pred=self.model_config.is_v_pred,
-)
+dtype=dtype,
+scheduler_type='dpm',
+device=self.device_torch,
+load_safety_checker=False,
+).to(self.device_torch)
+pipe.register_to_config(requires_safety_checker=False)
+text_encoder = pipe.text_encoder
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
-vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
-vae.eval()
-pipe = None
+tokenizer = pipe.tokenizer
+# scheduler doesn't get set sometimes, so we set it here
+pipe.scheduler = scheduler
+unet = pipe.unet
+noise_scheduler = pipe.scheduler
+vae = pipe.vae.to('cpu', dtype=dtype)
+vae.eval()
+vae.requires_grad_(False)
flush()
# just for now, or if we want to load a custom one
# put on cpu for now, we only need it when sampling
+# vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
+# vae.eval()
-self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl, pipeline=pipe)
+self.sd = StableDiffusion(
+vae,
+tokenizer,
+text_encoder,
+unet,
+noise_scheduler,
+is_xl=self.model_config.is_xl,
+pipeline=pipe
+)
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers:
vae.set_use_memory_efficient_attention_xformers(True)
unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
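For comparison, the stock diffusers loader the custom class builds on looks like this; from_single_file is the real diffusers API, while the path and device here are placeholders and the kwargs mirror the ones used above:

import torch
from diffusers import StableDiffusionPipeline

# stock-diffusers equivalent of the single-file load above; placeholder path
pipe = StableDiffusionPipeline.from_single_file(
    "/path/to/checkpoint.safetensors",
    torch_dtype=torch.float16,
    load_safety_checker=False,
)
pipe.to("cuda")
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
unet, vae, noise_scheduler = pipe.unet, pipe.vae, pipe.scheduler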
@@ -602,19 +572,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.multiplier = 1.0
else:
params = []
# assume dreambooth/finetune
if self.train_config.train_text_encoder:
-text_encoder.requires_grad_(True)
-text_encoder.train()
-params += text_encoder.parameters()
+if self.sd.is_xl:
+for te in text_encoder:
+te.requires_grad_(True)
+te.train()
+params += te.parameters()
+else:
+text_encoder.requires_grad_(True)
+text_encoder.train()
+params += text_encoder.parameters()
if self.train_config.train_unet:
unet.requires_grad_(True)
unet.train()
params += unet.parameters()
# TODO recover save if training network. Maybe load from beginning
### HOOK ###
params = self.hook_add_extra_train_params(params)
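A compact standalone mirror of the branching above: normalize to a list so one loop handles both the single SD text encoder and SDXL's pair (the helper name is illustrative):

import torch.nn as nn

def trainable_te_params(text_encoder):
    # treat text_encoder uniformly as a list of modules
    encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
    params = []
    for te in encoders:
        te.requires_grad_(True)
        te.train()
        params += list(te.parameters())
    return params

assert trainable_te_params([nn.Linear(2, 2)])  # works for the SDXL-style list too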
@@ -635,12 +612,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_train_loop()
+if self.has_first_sample_requested:
+self.print("Generating first sample from first sample config")
+self.sample(0, is_first=True)
# sample first
if self.train_config.skip_first_sample:
self.print("Skipping first sample due to config setting")
else:
+self.print("Generating baseline samples before training")
+self.sample(0, is_first=True)
-self.sample(0)
self.progress_bar = tqdm(
total=self.train_config.steps,

View File

@@ -129,9 +129,12 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
state_dict = {}
for prompt_txt, prompt_embeds in cache.prompts.items():
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu",
dtype=get_torch_dtype('fp16'))
if prompt_embeds.pooled_embeds is not None:
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu",
dtype=get_torch_dtype(
'fp16'))
save_file(state_dict, self.rescale_config.prompt_tensors)
self.print("Encoding complete.")
@@ -158,10 +161,15 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
]
prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
-prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
-neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
+if hasattr(prompt, 'pooled_embeds') \
+and hasattr(neutral, 'pooled_embeds') \
+and prompt.pooled_embeds is not None \
+and neutral.pooled_embeds is not None:
+prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
+neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
if prompt is None:
raise ValueError(f"Prompt {prompt_txt} is not in cache")

View File

@@ -1,7 +1,8 @@
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
-from diffusers import StableDiffusionXLPipeline
+from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
@@ -13,10 +14,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -28,16 +26,9 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
# predict_noise: bool = False,
timestep: Optional[int] = None,
):
r"""
@@ -226,8 +217,6 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
@@ -245,7 +234,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
-).to(device) # TODO DOES NOT CAST ORIGINALLY
+).to(device)  # TODO DOES NOT CAST ORIGINALLY
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -286,3 +275,260 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
print('Called cpu offload', gpu_id)
# fuck off
pass
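Below, the new class gives SD 1.5/2.1 checkpoints the same calling convention as the SDXL pipeline. A hypothetical usage sketch (the checkpoint path is a placeholder):

# hypothetical usage of CustomStableDiffusionPipeline defined below
pipe = CustomStableDiffusionPipeline.from_single_file(
    "/path/to/sd15.safetensors",
    load_safety_checker=False,
)
image = pipe(
    prompt="a photo of a cat",
    prompt_2="a photo of a cat",  # accepted for SDXL parity, unused on SD 1.5
    negative_prompt="blurry",
    negative_prompt_2="blurry",
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]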
class CustomStableDiffusionPipeline(StableDiffusionPipeline):
# replace the call so it matches the SDXL call signature, letting us use the same code and also stop early
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
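Because __call__ now honors denoising_end like the SDXL pipeline, sampling can stop partway through the schedule; a sketch, assuming the pipe from the usage example above:

# stop after the first 60% of the schedule and return raw latents instead of an image
partial_latents = pipe(
    prompt="a photo of a cat",
    num_inference_steps=50,
    denoising_end=0.6,
    output_type="latent",
).images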
# some of the inputs are only there to keep it compatible with sdxl
def predict_noise(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
crops_coords_top_left: Tuple[int, int] = (0, 0),
timestep: Optional[int] = None,
):
# 0. Default height and width to unet
height = self.unet.config.sample_size * self.vae_scale_factor
width = self.unet.config.sample_size * self.vae_scale_factor
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
return noise_pred
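And a sketch of driving the new predict_noise directly for a single timestep (pipe, the latent shape, and the timestep value are illustrative):

import torch

# one-step noise prediction; 64x64 latents correspond to 512x512 pixels on SD 1.5
latents = torch.randn(1, 4, 64, 64, device=pipe.device, dtype=pipe.unet.dtype)
noise = pipe.predict_noise(
    prompt="a photo of a cat",
    negative_prompt="blurry",
    latents=latents,
    guidance_scale=7.5,
    timestep=500,
)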