mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Pipelines working on SDXL for noise prediction
This commit is contained in:
@@ -18,7 +18,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -500,13 +500,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
# do our own scheduler
|
||||
scheduler = KDPM2DiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
|
||||
pipe = CustomStableDiffusionXLPipeline.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
pipe.scheduler = scheduler
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
unet = pipe.unet
|
||||
@@ -637,10 +645,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.progress_bar = tqdm(
|
||||
total=self.train_config.steps,
|
||||
desc=self.job.name,
|
||||
leave=True
|
||||
leave=True,
|
||||
initial=self.step_num,
|
||||
iterable=range(0, self.train_config.steps),
|
||||
)
|
||||
# set it to our current step in case it was updated from a load
|
||||
self.progress_bar.update(self.step_num)
|
||||
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
# todo handle dataloader here maybe, not sure
|
||||
|
||||
@@ -171,9 +171,9 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
# self.sd.noise_scheduler.set_timesteps(
|
||||
# self.train_config.max_denoising_steps, device=self.device_torch
|
||||
# )
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@@ -183,6 +183,12 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
).item()
|
||||
absolute_total_timesteps = 1000
|
||||
|
||||
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
|
||||
# pad with spaces
|
||||
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
|
||||
new_description = f"{self.job.name} ts: {timestep_str}"
|
||||
self.progress_bar.set_description(new_description)
|
||||
|
||||
# get noise
|
||||
latents = self.get_latent_noise(
|
||||
pixel_height=self.rescale_config.from_resolution,
|
||||
@@ -190,21 +196,37 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
denoised_fraction = timesteps_to / absolute_total_timesteps
|
||||
self.sd.pipeline.to(self.device_torch)
|
||||
torch.set_default_device(self.device_torch)
|
||||
|
||||
denoised_latents = self.sd.pipeline(
|
||||
num_inference_steps=1000,
|
||||
denoising_end=denoised_fraction,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
output_type="latent",
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
guidance_scale=3,
|
||||
).images.to(self.device_torch, dtype=dtype)
|
||||
# turn off progress bar
|
||||
self.sd.pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
current_timestep = timesteps_to
|
||||
pre_train = False
|
||||
|
||||
if not pre_train:
|
||||
# partially denoise the latents
|
||||
denoised_latents = self.sd.pipeline(
|
||||
num_inference_steps=self.train_config.max_denoising_steps,
|
||||
denoising_end=denoised_fraction,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
output_type="latent",
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
guidance_scale=3,
|
||||
).images.to(self.device_torch, dtype=dtype)
|
||||
current_timestep = timesteps_to
|
||||
|
||||
else:
|
||||
denoised_latents = latents
|
||||
current_timestep = 1
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000
|
||||
)
|
||||
|
||||
from_prediction = self.sd.pipeline.predict_noise(
|
||||
latents=denoised_latents,
|
||||
@@ -213,10 +235,13 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=2
|
||||
guidance_scale=1,
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
# predict_noise=True,
|
||||
num_inference_steps=1000,
|
||||
)
|
||||
|
||||
reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
|
||||
reduced_from_prediction = self.reduce_size_fn(from_prediction)
|
||||
|
||||
# get noise prediction at reduced scale
|
||||
to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
|
||||
@@ -233,7 +258,10 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=2
|
||||
guidance_scale=1,
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
# predict_noise=True,
|
||||
num_inference_steps=1000,
|
||||
)
|
||||
|
||||
reduced_from_prediction.requires_grad = False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
@@ -13,17 +13,32 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
timestep: Optional[int] = 1,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
# predict_noise: bool = False,
|
||||
timestep: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -35,6 +50,20 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
@@ -48,6 +77,14 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@@ -69,6 +106,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -78,14 +124,59 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
timestep (`int`, *optional*, defaults to `1`):
|
||||
The timestep at which to generate the image. If not specified, the last timestep is used.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: Predicted noise
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# if not predict_noise:
|
||||
# # call parent
|
||||
# return super().__call__(
|
||||
# prompt=prompt,
|
||||
# prompt_2=prompt_2,
|
||||
# height=height,
|
||||
# width=width,
|
||||
# num_inference_steps=num_inference_steps,
|
||||
# denoising_end=denoising_end,
|
||||
# guidance_scale=guidance_scale,
|
||||
# negative_prompt=negative_prompt,
|
||||
# negative_prompt_2=negative_prompt_2,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
# eta=eta,
|
||||
# generator=generator,
|
||||
# latents=latents,
|
||||
# prompt_embeds=prompt_embeds,
|
||||
# negative_prompt_embeds=negative_prompt_embeds,
|
||||
# pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
# negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
# output_type=output_type,
|
||||
# return_dict=return_dict,
|
||||
# callback=callback,
|
||||
# callback_steps=callback_steps,
|
||||
# cross_attention_kwargs=cross_attention_kwargs,
|
||||
# guidance_rescale=guidance_rescale,
|
||||
# original_size=original_size,
|
||||
# crops_coords_top_left=crops_coords_top_left,
|
||||
# target_size=target_size,
|
||||
# )
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = self.default_sample_size * self.vae_scale_factor
|
||||
width = self.default_sample_size * self.vae_scale_factor
|
||||
@@ -106,16 +197,12 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
# do_classifier_free_guidance = guidance_scale > 1.0
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
num_images_per_prompt = 1
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
@@ -137,7 +224,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(1, device=device)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
@@ -150,16 +237,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
None,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0)
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
).to(device)
|
||||
).to(device) # TODO DOES NOT CAST ORIGINALLY
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
@@ -172,13 +258,13 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=timestep,
|
||||
timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user