From 596e57a6a6d2b4842296b55cae50535f84cb42d2 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 27 Jul 2023 11:24:33 -0600 Subject: [PATCH] Pipelines working on SDXL for noise prediction --- jobs/process/BaseSDTrainProcess.py | 19 +++-- jobs/process/TrainSDRescaleProcess.py | 66 ++++++++++----- toolkit/pipelines.py | 116 ++++++++++++++++++++++---- 3 files changed, 162 insertions(+), 39 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8e9cc12a..85e39a23 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -18,7 +18,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors @@ -500,13 +500,21 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) if self.model_config.is_xl: + # do our own scheduler + scheduler = KDPM2DiscreteScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.0120, + beta_schedule="scaled_linear", + ) pipe = CustomStableDiffusionXLPipeline.from_single_file( self.model_config.name_or_path, dtype=dtype, scheduler_type='dpm', - device=self.device_torch + device=self.device_torch, ).to(self.device_torch) + pipe.scheduler = scheduler text_encoders = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] unet = pipe.unet @@ -637,10 +645,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.progress_bar = tqdm( total=self.train_config.steps, desc=self.job.name, - leave=True + leave=True, + initial=self.step_num, + iterable=range(0, self.train_config.steps), ) - # set it to our current step in case it was updated from a load - self.progress_bar.update(self.step_num) + # self.step_num = 0 for step in range(self.step_num, self.train_config.steps): # todo handle dataloader here maybe, not sure diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py index 3c45170a..86b39f32 100644 --- a/jobs/process/TrainSDRescaleProcess.py +++ b/jobs/process/TrainSDRescaleProcess.py @@ -171,9 +171,9 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): loss_function = torch.nn.MSELoss() with torch.no_grad(): - self.sd.noise_scheduler.set_timesteps( - self.train_config.max_denoising_steps, device=self.device_torch - ) + # self.sd.noise_scheduler.set_timesteps( + # self.train_config.max_denoising_steps, device=self.device_torch + # ) self.optimizer.zero_grad() @@ -183,6 +183,12 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): ).item() absolute_total_timesteps = 1000 + max_len_timestep_str = len(str(self.train_config.max_denoising_steps)) + # pad with spaces + timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ") + new_description = f"{self.job.name} ts: {timestep_str}" + self.progress_bar.set_description(new_description) + # get noise latents = self.get_latent_noise( pixel_height=self.rescale_config.from_resolution, @@ -190,21 +196,37 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): ).to(self.device_torch, dtype=dtype) denoised_fraction = timesteps_to / absolute_total_timesteps + self.sd.pipeline.to(self.device_torch) + torch.set_default_device(self.device_torch) - denoised_latents = self.sd.pipeline( - num_inference_steps=1000, - denoising_end=denoised_fraction, - latents=latents, - prompt_embeds=prompt.text_embeds, - negative_prompt_embeds=neutral.text_embeds, - pooled_prompt_embeds=prompt.pooled_embeds, - negative_pooled_prompt_embeds=neutral.pooled_embeds, - output_type="latent", - num_images_per_prompt=self.train_config.batch_size, - guidance_scale=3, - ).images.to(self.device_torch, dtype=dtype) + # turn off progress bar + self.sd.pipeline.set_progress_bar_config(disable=True) - current_timestep = timesteps_to + pre_train = False + + if not pre_train: + # partially denoise the latents + denoised_latents = self.sd.pipeline( + num_inference_steps=self.train_config.max_denoising_steps, + denoising_end=denoised_fraction, + latents=latents, + prompt_embeds=prompt.text_embeds, + negative_prompt_embeds=neutral.text_embeds, + pooled_prompt_embeds=prompt.pooled_embeds, + negative_pooled_prompt_embeds=neutral.pooled_embeds, + output_type="latent", + num_images_per_prompt=self.train_config.batch_size, + guidance_scale=3, + ).images.to(self.device_torch, dtype=dtype) + current_timestep = timesteps_to + + else: + denoised_latents = latents + current_timestep = 1 + + self.sd.noise_scheduler.set_timesteps( + 1000 + ) from_prediction = self.sd.pipeline.predict_noise( latents=denoised_latents, @@ -213,10 +235,13 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): pooled_prompt_embeds=prompt.pooled_embeds, negative_pooled_prompt_embeds=neutral.pooled_embeds, timestep=current_timestep, - guidance_scale=2 + guidance_scale=1, + num_images_per_prompt=self.train_config.batch_size, + # predict_noise=True, + num_inference_steps=1000, ) - reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32) + reduced_from_prediction = self.reduce_size_fn(from_prediction) # get noise prediction at reduced scale to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype) @@ -233,7 +258,10 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): pooled_prompt_embeds=prompt.pooled_embeds, negative_pooled_prompt_embeds=neutral.pooled_embeds, timestep=current_timestep, - guidance_scale=2 + guidance_scale=1, + num_images_per_prompt=self.train_config.batch_size, + # predict_noise=True, + num_inference_steps=1000, ) reduced_from_prediction.requires_grad = False diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py index 69d915c9..c65d6df1 100644 --- a/toolkit/pipelines.py +++ b/toolkit/pipelines.py @@ -1,4 +1,4 @@ -from typing import Union, List, Optional, Dict, Any, Tuple +from typing import Union, List, Optional, Dict, Any, Tuple, Callable import torch from diffusers import StableDiffusionXLPipeline @@ -13,17 +13,32 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, - timestep: Optional[int] = 1, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + # predict_noise: bool = False, + timestep: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -35,6 +50,20 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -48,6 +77,14 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -69,6 +106,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -78,14 +124,59 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. - timestep (`int`, *optional*, defaults to `1`): - The timestep at which to generate the image. If not specified, the last timestep is used. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Examples: Returns: - torch.FloatTensor: Predicted noise + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. """ + # if not predict_noise: + # # call parent + # return super().__call__( + # prompt=prompt, + # prompt_2=prompt_2, + # height=height, + # width=width, + # num_inference_steps=num_inference_steps, + # denoising_end=denoising_end, + # guidance_scale=guidance_scale, + # negative_prompt=negative_prompt, + # negative_prompt_2=negative_prompt_2, + # num_images_per_prompt=num_images_per_prompt, + # eta=eta, + # generator=generator, + # latents=latents, + # prompt_embeds=prompt_embeds, + # negative_prompt_embeds=negative_prompt_embeds, + # pooled_prompt_embeds=pooled_prompt_embeds, + # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + # output_type=output_type, + # return_dict=return_dict, + # callback=callback, + # callback_steps=callback_steps, + # cross_attention_kwargs=cross_attention_kwargs, + # guidance_rescale=guidance_rescale, + # original_size=original_size, + # crops_coords_top_left=crops_coords_top_left, + # target_size=target_size, + # ) + # 0. Default height and width to unet height = self.default_sample_size * self.vae_scale_factor width = self.default_sample_size * self.vae_scale_factor @@ -106,16 +197,12 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. - # do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - - num_images_per_prompt = 1 - ( prompt_embeds, negative_prompt_embeds, @@ -137,7 +224,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): ) # 4. Prepare timesteps - self.scheduler.set_timesteps(1, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -150,16 +237,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): width, prompt_embeds.dtype, device, - None, + generator, latents, ) # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds - crops_coords_top_left: Tuple[int, int] = (0, 0) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - ).to(device) + ).to(device) # TODO DOES NOT CAST ORIGINALLY if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -172,13 +258,13 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, - timestep=timestep, + timestep, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs,