diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e4105ff8..8e9cc12a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -13,10 +13,12 @@ from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys
 
+from toolkit.pipelines import CustomStableDiffusionXLPipeline
+
 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
 
-from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 
 from jobs.process import BaseTrainProcess
 from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
@@ -100,15 +102,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # self.sd.tokenizer.to(self.device_torch)
         # TODO add clip skip
         if self.sd.is_xl:
-            pipeline = StableDiffusionXLPipeline(
-                vae=self.sd.vae,
-                unet=self.sd.unet,
-                text_encoder=self.sd.text_encoder[0],
-                text_encoder_2=self.sd.text_encoder[1],
-                tokenizer=self.sd.tokenizer[0],
-                tokenizer_2=self.sd.tokenizer[1],
-                scheduler=self.sd.noise_scheduler,
-            )
+            pipeline = self.sd.pipeline
         else:
             pipeline = StableDiffusionPipeline(
                 vae=self.sd.vae,
@@ -209,7 +203,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 img.save(output_path)
 
         # clear pipeline and cache to reduce vram usage
-        del pipeline
+        if not self.sd.is_xl:
+            del pipeline
         torch.cuda.empty_cache()
 
         # restore training state
@@ -363,6 +358,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             return None
 
+    def predict_noise_xl(
+        self,
+        latents: torch.FloatTensor,
+        positive_prompt: str,
+        negative_prompt: str,
+        timestep: int,
+        guidance_scale=7.5,
+        guidance_rescale=0.7,
+        add_time_ids=None,
+        **kwargs,
+    ):
+        # not implemented yet; XL noise prediction currently goes through
+        # CustomStableDiffusionXLPipeline.predict_noise instead
+        pass
+
+
     def predict_noise(
         self,
         latents: torch.FloatTensor,
@@ -492,12 +501,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         if self.model_config.is_xl:
-            pipe = StableDiffusionXLPipeline.from_single_file(
+            pipe = CustomStableDiffusionXLPipeline.from_single_file(
                 self.model_config.name_or_path,
                 dtype=dtype,
-                scheduler_type='pndm',
+                scheduler_type='dpm',
                 device=self.device_torch
-            )
+            ).to(self.device_torch)
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
             unet = pipe.unet
@@ -513,7 +522,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             text_encoder = text_encoders
             tokenizer = tokenizer
 
-            del pipe
 
             flush()
@@ -529,6 +537,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             text_encoder.eval()
             vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
             vae.eval()
+            pipe = None
 
             flush()
@@ -536,7 +545,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # put on cpu for now, we only need it when sampling
         # vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
         # vae.eval()
-        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
+        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl, pipeline=pipe)
 
         unet.to(self.device_torch, dtype=dtype)
         if self.train_config.xformers:
diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py
index 582d6ffa..3c45170a 100644
--- a/jobs/process/TrainSDRescaleProcess.py
+++ b/jobs/process/TrainSDRescaleProcess.py
@@ -157,33 +157,19 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
                 torch.randint(0, len(self.prompt_txt_list), (1,)).item()
             ]
             prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
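+            # SDXL prompts encode to a pair of tensors (sequence embeds and pooled
+            # embeds); Tensor.to() is not in-place, so the results are assigned back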
+            prompt.text_embeds = prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
+            prompt.pooled_embeds = prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
             neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
+            neutral.text_embeds = neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
+            neutral.pooled_embeds = neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
             if prompt is None:
                 raise ValueError(f"Prompt {prompt_txt} is not in cache")
 
-            prompt_batch = train_tools.concat_prompt_embeddings(
-                prompt,
-                neutral,
-                self.train_config.batch_size,
-            )
-
             noise_scheduler = self.sd.noise_scheduler
             optimizer = self.optimizer
             lr_scheduler = self.lr_scheduler
             loss_function = torch.nn.MSELoss()
 
-            def get_noise_pred(p, n, gs, cts, dn):
-                return self.predict_noise(
-                    latents=dn,
-                    text_embeddings=train_tools.concat_prompt_embeddings(
-                        p,  # unconditional
-                        n,  # positive
-                        self.train_config.batch_size,
-                    ),
-                    timestep=cts,
-                    guidance_scale=gs,
-                )
-
             with torch.no_grad():
                 self.sd.noise_scheduler.set_timesteps(
                     self.train_config.max_denoising_steps, device=self.device_torch
                 )
@@ -195,52 +181,60 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
             timesteps_to = torch.randint(
                 1, self.train_config.max_denoising_steps, (1,)
             ).item()
+            absolute_total_timesteps = 1000
 
             # get noise
-            noise = self.get_latent_noise(
+            latents = self.get_latent_noise(
                 pixel_height=self.rescale_config.from_resolution,
                 pixel_width=self.rescale_config.from_resolution,
             ).to(self.device_torch, dtype=dtype)
 
-            # get latents
-            latents = noise * self.sd.noise_scheduler.init_noise_sigma
-            latents = latents.to(self.device_torch, dtype=dtype)
-            #
-            # # predict without network
-            # assert self.network.is_active is False
-            # denoised_latents = self.diffuse_some_steps(
-            #     latents,  # pass simple noise latents
-            #     prompt_batch,
-            #     start_timesteps=0,
-            #     total_timesteps=timesteps_to,
-            #     guidance_scale=3,
-            # )
-            # noise_scheduler.set_timesteps(1000)
-            #
-            # current_timestep = noise_scheduler.timesteps[
-            #     int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
-            # ]
+            denoised_fraction = timesteps_to / absolute_total_timesteps
 
-            current_timestep = 0
-            denoised_latents = latents
-            # get noise prediction at full scale
-            from_prediction = get_noise_pred(
-                prompt, neutral, 1, current_timestep, denoised_latents
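+            # `denoising_end` is a fraction of the schedule in [0, 1]; with 1000
+            # inference steps the pipeline stops early and returns latents that are
+            # only partially denoised (roughly `timesteps_to` steps in)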
+            denoised_latents = self.sd.pipeline(
+                num_inference_steps=1000,
+                denoising_end=denoised_fraction,
+                latents=latents,
+                prompt_embeds=prompt.text_embeds,
+                negative_prompt_embeds=neutral.text_embeds,
+                pooled_prompt_embeds=prompt.pooled_embeds,
+                negative_pooled_prompt_embeds=neutral.pooled_embeds,
+                output_type="latent",
+                num_images_per_prompt=self.train_config.batch_size,
+                guidance_scale=3,
+            ).images.to(self.device_torch, dtype=dtype)
+
+            current_timestep = timesteps_to
+
+            from_prediction = self.sd.pipeline.predict_noise(
+                latents=denoised_latents,
+                prompt_embeds=prompt.text_embeds,
+                negative_prompt_embeds=neutral.text_embeds,
+                pooled_prompt_embeds=prompt.pooled_embeds,
+                negative_pooled_prompt_embeds=neutral.pooled_embeds,
+                timestep=current_timestep,
+                guidance_scale=2
             )
             reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
 
             # get noise prediction at reduced scale
-            to_denoised_latents = self.reduce_size_fn(denoised_latents)
+            to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
 
             # start gradient
             optimizer.zero_grad()
 
             self.network.multiplier = 1.0
             with self.network:
                 assert self.network.is_active is True
-                to_prediction = get_noise_pred(
-                    prompt, neutral, 1, current_timestep, to_denoised_latents
-                ).to("cpu", dtype=torch.float32)
+                to_prediction = self.sd.pipeline.predict_noise(
+                    latents=to_denoised_latents,
+                    prompt_embeds=prompt.text_embeds,
+                    negative_prompt_embeds=neutral.text_embeds,
+                    pooled_prompt_embeds=prompt.pooled_embeds,
+                    negative_pooled_prompt_embeds=neutral.pooled_embeds,
+                    timestep=current_timestep,
+                    guidance_scale=2
+                )
 
             reduced_from_prediction.requires_grad = False
             from_prediction.requires_grad = False
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
new file mode 100644
index 00000000..69d915c9
--- /dev/null
+++ b/toolkit/pipelines.py
@@ -0,0 +1,202 @@
+from typing import Union, List, Optional, Dict, Any, Tuple
+
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+
+
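+# Thin subclass of StableDiffusionXLPipeline: exposes a single-step
+# `predict_noise` helper for training and turns model CPU offload into a no-op.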
+class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
+    # def __init__(self, *args, **kwargs):
+    #     super().__init__(*args, **kwargs)
+
+    def predict_noise(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        timestep: Optional[int] = 1,
+    ):
+        r"""
+        Predict the noise residual for `latents` at a single `timestep`, without
+        running the full denoising loop.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
+            timestep (`int`, *optional*, defaults to `1`):
+                The diffusion timestep at which to predict the noise.
+
+        Examples:
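+            A minimal sketch; the checkpoint path and the latent shape are
+            illustrative, not fixed by the API:
+
+            ```py
+            >>> pipe = CustomStableDiffusionXLPipeline.from_single_file("sd_xl_base_1.0.safetensors")
+            >>> latents = torch.randn(1, 4, 128, 128)  # 1024x1024 image -> 128x128 latents
+            >>> noise_pred = pipe.predict_noise(
+            ...     prompt="a photo of a cat",
+            ...     negative_prompt="blurry",
+            ...     latents=latents,
+            ...     timestep=500,
+            ...     guidance_scale=5.0,
+            ... )
+            >>> noise_pred.shape  # matches the latent shape
+            torch.Size([1, 4, 128, 128])
+            ```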
+
+        Returns:
+            torch.FloatTensor: Predicted noise
+        """
+        # 0. Default height and width to unet
+        height = self.default_sample_size * self.vae_scale_factor
+        width = self.default_sample_size * self.vae_scale_factor
+
+        original_size = (height, width)
+        target_size = (height, width)
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+
+        num_images_per_prompt = 1
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(1, device=device)
+
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            None,
+            latents,
+        )
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = pooled_prompt_embeds
+        crops_coords_top_left: Tuple[int, int] = (0, 0)
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+        ).to(device)
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
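+        # note: `timesteps` (from set_timesteps(1) above) is only used to scale the
+        # model input, while the caller-supplied `timestep` conditions the UNet; for
+        # sigma-based schedulers the two would need to agree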
+        latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps)
+
+        # predict the noise residual
+        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+        noise_pred = self.unet(
+            latent_model_input,
+            timestep=timestep,
+            encoder_hidden_states=prompt_embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
+            added_cond_kwargs=added_cond_kwargs,
+            return_dict=False,
+        )[0]
+
+        # perform guidance
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        if do_classifier_free_guidance and guidance_rescale > 0.0:
+            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+        return noise_pred
+
+    def enable_model_cpu_offload(self, gpu_id=0):
+        print('Called cpu offload', gpu_id)
+        # intentionally a no-op: offload hooks would move modules off the GPU
+        # while they are still needed for training
+        pass
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 9b1f6c09..a763a988 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1,3 +1,4 @@
+import typing
 from typing import Union, OrderedDict
 import sys
 import os
@@ -36,7 +37,15 @@ class PromptEmbeds:
         return self
 
 
+# imports used only for type checking (not needed at runtime)
+if typing.TYPE_CHECKING:
+    from diffusers import StableDiffusionPipeline
+    from toolkit.pipelines import CustomStableDiffusionXLPipeline
+
+
 class StableDiffusion:
+    pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
+
     def __init__(
             self,
             vae,
@@ -44,7 +53,8 @@ class StableDiffusion:
             text_encoder,
             unet,
             noise_scheduler,
-            is_xl=False
+            is_xl=False,
+            pipeline=None,
     ):
         # text encoder has a list of 2 for xl
         self.vae = vae
@@ -53,6 +63,7 @@ class StableDiffusion:
         self.unet = unet
         self.noise_scheduler = noise_scheduler
         self.is_xl = is_xl
+        self.pipeline = pipeline
 
     def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
         prompt = prompt