diff --git a/config/examples/train_lora_chroma_24gb.yaml b/config/examples/train_lora_chroma_24gb.yaml new file mode 100644 index 00000000..f3110652 --- /dev/null +++ b/config/examples/train_lora_chroma_24gb.yaml @@ -0,0 +1,97 @@ +--- +job: extension +config: + # this name will be used for the output folder and file names + name: "my_first_chroma_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false # change this to true to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log in +# hf_repo_id: your-username/your-model-slug +# hf_private: true # whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of the time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train; 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with chroma + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre-training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use the new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this for chroma if your gpu supports it; other dtypes may not work correctly + dtype: bf16 + model: + # Download whichever model you prefer from the Chroma repo + # https://huggingface.co/lodestones/Chroma/tree/main + # point to it here. 
+ name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors" + arch: "chroma" + quantize: true # run 8bit mixed precision + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # negative prompt, optional + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py new file mode 100644 index 00000000..90724450 --- /dev/null +++ b/extensions_built_in/diffusion_models/__init__.py @@ -0,0 +1,6 @@ +from .chroma import ChromaModel + +AI_TOOLKIT_MODELS = [ + # put a list of models here + ChromaModel +] diff --git a/extensions_built_in/diffusion_models/chroma/__init__.py b/extensions_built_in/diffusion_models/chroma/__init__.py new file mode 100644 index 00000000..b20e2f40 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/__init__.py @@ -0,0 +1 @@ +from .chroma_model import ChromaModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py new file mode 100644 index 00000000..d3a92049 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -0,0 +1,388 @@ +import os +from typing import TYPE_CHECKING + +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from diffusers import AutoencoderKL +# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import ChromaPipeline +from einops import rearrange, repeat +import random +import torch.nn.functional as F +from .src.model import Chroma, chroma_params +from safetensors.torch import load_file, save_file +from toolkit.metadata import get_meta_for_safetensors + +if 
TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.attention_head_dim = 128 + self.guidance_embeds = True + self.in_channels = 64 + self.joint_attention_dim = 4096 + self.num_attention_heads = 24 + self.num_layers = 19 + self.num_single_layers = 38 + self.patch_size = 1 + +class FakeCLIP(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.bfloat16 + self.device = 'cuda' + self.text_model = None + self.tokenizer = None + self.model_max_length = 77 + + def forward(self, *args, **kwargs): + return torch.zeros(1, 1, 1).to(self.device) + + +class ChromaModel(BaseModel): + arch = "chroma" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['Chroma'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # return the bucket divisibility for the model + return 32 + + def load_model(self): + dtype = self.torch_dtype + + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + + extras_path = 'black-forest-labs/FLUX.1-schnell' + + self.print_and_status_update("Loading transformer") + + transformer = Chroma(chroma_params) + + # add dtype, not sure why it doesnt have it + transformer.dtype = dtype + + chroma_state_dict = load_file(model_path, 'cpu') + # load the state dict into the model + transformer.load_state_dict(chroma_state_dict) + + transformer.to(self.quantize_device, dtype=dtype) + + transformer.config = FakeConfig() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + # self.print_and_status_update("Loading CLIP") + text_encoder = FakeCLIP() + tokenizer = FakeCLIP() + text_encoder.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ChromaModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + extras_path, + subfolder="vae", + torch_dtype=dtype + ) + vae = 
vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: ChromaPipeline = ChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ChromaModel.get_train_scheduler() + pipeline = ChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer) + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attn_mask=conditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", + b=bs).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32) + guidance = guidance.expand(latent_model_input_packed.shape[0]) + + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + img=latent_model_input_packed.to( + self.device_torch, cast_dtype + ), + img_ids=img_ids, + txt=text_embeddings.text_embeds.to( + self.device_torch, 
cast_dtype + ), + txt_ids=txt_ids, + txt_mask=text_embeddings.attention_mask.to( + self.device_torch, cast_dtype + ), + timesteps=timestep / 1000, + guidance=guidance + ) + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=self.vae.config.latent_channels + ) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + max_length = 512 + + device = self.text_encoder[1].device + dtype = self.text_encoder[1].dtype + + # T5 + text_inputs = self.tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + prompt_attention_mask = text_inputs["attention_mask"] + + pe = PromptEmbeds( + prompt_embeds + ) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return self.model.final_layer.linear.weight.requires_grad + + def get_te_has_grad(self): + # return from a weight if it has grad + return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: Chroma = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to('cpu', dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name='chroma') + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd diff --git a/extensions_built_in/diffusion_models/chroma/pipeline.py b/extensions_built_in/diffusion_models/chroma/pipeline.py new file mode 100644 index 00000000..52b9b817 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -0,0 +1,195 @@ +from typing import Union, List, Optional, Dict, Any, Callable + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class ChromaPipeline(FluxPipeline): + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attn_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attn_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[ + int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) + if guidance_scale > 1.00001: + negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) + + # 4. Prepare latent variables + num_channels_latents = 64 // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # extend img ids to match batch size + latent_image_ids = latent_image_ids.unsqueeze(0) + latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + guidance = torch.full([1], 0, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + + noise_pred_text = self.transformer( + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=text_ids, + txt_mask=prompt_attn_mask, # todo add this + timesteps=timestep / 1000, + guidance=guidance + ) + + if guidance_scale > 1.00001: + noise_pred_uncond = self.transformer( + img=latents, + img_ids=latent_image_ids, + txt=negative_prompt_embeds, + txt_ids=negative_text_ids, + txt_mask=negative_prompt_attn_mask, # todo add this + timesteps=timestep / 1000, + guidance=guidance + ) + + noise_pred = noise_pred_uncond + self.guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + \ + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess( + image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/extensions_built_in/diffusion_models/chroma/src/__init__.py b/extensions_built_in/diffusion_models/chroma/src/__init__.py new file mode 100644 index 00000000..f3454029 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/__init__.py @@ -0,0 +1 @@ +# This is taken and slightly modified from https://github.com/lodestone-rock/flow/tree/master/src/models/chroma \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/chroma/src/layers.py b/extensions_built_in/diffusion_models/chroma/src/layers.py new file mode 100644 index 00000000..726ec6a0 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/layers.py @@ -0,0 +1,505 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn +import torch.nn.functional as F + +from .math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, use_compiled: bool = False): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + self.use_compiled = use_compiled + + def _forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + def forward(self, x: Tensor): + return F.rms_norm(x, self.scale.shape, weight=self.scale, eps=1e-6) + # if self.use_compiled: + # return torch.compile(self._forward)(x) + # else: + # return self._forward(x) + + +def distribute_modulations(tensor: torch.Tensor): + """ + Distributes slices of the tensor into the block_dict as ModulationOut objects. + + Args: + tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. + """ + batch_size, vectors, dim = tensor.shape + + block_dict = {} + + # HARD CODED VALUES! lookup table for the generated vectors + # TODO: move this into chroma config! 
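+    # Layout sanity check for the flat modulation tensor consumed below; every entry is a
+    # (shift, scale, gate) triple except the final layer, which only needs (shift, scale):
+    #   38 single blocks               * 3             = 114
+    #   19 double blocks (img_mod)     * 2 * 3         = 114
+    #   19 double blocks (txt_mod)     * 2 * 3         = 114
+    #   final_layer.adaLN_modulation.1                 =   2
+    #   total                                          = 344  (matches Chroma.mod_index_length)
+    # The lite_double_blocks.* keys registered below belong to the 6.2b variant; with the
+    # 344-entry table used here they slice past the end of the tensor and are not used by
+    # Chroma.forward.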
+ # Add 38 single mod blocks + for i in range(38): + key = f"single_blocks.{i}.modulation.lin" + block_dict[key] = None + + # Add 19 image double blocks + for i in range(19): + key = f"double_blocks.{i}.img_mod.lin" + block_dict[key] = None + + # Add 19 text double blocks + for i in range(19): + key = f"double_blocks.{i}.txt_mod.lin" + block_dict[key] = None + + # Add the final layer + block_dict["final_layer.adaLN_modulation.1"] = None + # 6.2b version + block_dict["lite_double_blocks.4.img_mod.lin"] = None + block_dict["lite_double_blocks.4.txt_mod.lin"] = None + + idx = 0 # Index to keep track of the vector slices + + for key in block_dict.keys(): + if "single_blocks" in key: + # Single block: 1 ModulationOut + block_dict[key] = ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + idx += 3 # Advance by 3 vectors + + elif "img_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "txt_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "final_layer" in key: + # Final layer: 1 ModulationOut + block_dict[key] = [ + tensor[:, idx : idx + 1, :], + tensor[:, idx + 1 : idx + 2, :], + ] + idx += 2 # Advance by 3 vectors + + return block_dict + + +class Approximator(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4): + super().__init__() + self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True) + self.layers = nn.ModuleList( + [MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)] + ) + self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)]) + self.out_proj = nn.Linear(hidden_dim, out_dim) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int, use_compiled: bool = False): + super().__init__() + self.query_norm = RMSNorm(dim, use_compiled=use_compiled) + self.key_norm = RMSNorm(dim, use_compiled=use_compiled) + self.use_compiled = use_compiled + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + use_compiled: bool = False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim, use_compiled=use_compiled) + self.proj = nn.Linear(dim, dim) + self.use_compiled = use_compiled + + def forward(self, x: Tensor, pe: 
Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +def _modulation_shift_scale_fn(x, scale, shift): + return (1 + scale) * x + shift + + +def _modulation_gate_fn(x, gate, gate_params): + return x + gate * gate_params + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool = False, + use_compiled: bool = False, + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_compiled=use_compiled, + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_compiled=use_compiled, + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + self.use_compiled = use_compiled + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + if self.use_compiled: + return torch.compile(_modulation_shift_scale_fn)(x, scale, shift) + else: + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + if self.use_compiled: + return torch.compile(_modulation_gate_fn)(x, gate, gate_params) + else: + return _modulation_gate_fn(x, gate, gate_params) + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + distill_vec: list[ModulationOut], + mask: Tensor, + ) -> tuple[Tensor, Tensor]: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + # replaced with compiled fn + # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = self.modulation_shift_scale_fn( + img_modulated, img_mod1.scale, img_mod1.shift + ) + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + # replaced with compiled fn + # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = self.modulation_shift_scale_fn( + txt_modulated, txt_mod1.scale, txt_mod1.shift + ) + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) 
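+        # txt tokens go first in the joint sequence (dim=2 is the sequence axis of the
+        # [B, H, L, D] layout produced by the rearranges above); this ordering matches pe,
+        # which Chroma.forward builds from torch.cat((txt_ids, img_ids), dim=1), the
+        # txt/img split of `attn` below, and the txt_img_mask layout.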
+ k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe, mask=mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + # replaced with compiled fn + # img = img + img_mod1.gate * self.img_attn.proj(img_attn) + # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn)) + img = self.modulation_gate_fn( + img, + img_mod2.gate, + self.img_mlp( + self.modulation_shift_scale_fn( + self.img_norm2(img), img_mod2.scale, img_mod2.shift + ) + ), + ) + + # calculate the txt bloks + # replaced with compiled fn + # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn)) + txt = self.modulation_gate_fn( + txt, + txt_mod2.gate, + self.txt_mlp( + self.modulation_shift_scale_fn( + self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift + ) + ), + ) + + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + use_compiled: bool = False, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim, use_compiled=use_compiled) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.use_compiled = use_compiled + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + if self.use_compiled: + return torch.compile(_modulation_shift_scale_fn)(x, scale, shift) + else: + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + if self.use_compiled: + return torch.compile(_modulation_gate_fn)(x, gate, gate_params) + else: + return _modulation_gate_fn(x, gate, gate_params) + + def forward( + self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor + ) -> Tensor: + mod = distill_vec + # replaced with compiled fn + # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift) + qkv, mlp = torch.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, mask=mask) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + # replaced with compiled fn + # return x + mod.gate * output + 
return self.modulation_gate_fn(x, mod.gate, output) + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + use_compiled: bool = False, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.use_compiled = use_compiled + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + if self.use_compiled: + return torch.compile(_modulation_shift_scale_fn)(x, scale, shift) + else: + return _modulation_shift_scale_fn(x, scale, shift) + + def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor: + shift, scale = distill_vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + # replaced with compiled fn + # x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.modulation_shift_scale_fn( + self.norm_final(x), scale[:, None, :], shift[:, None, :] + ) + x = self.linear(x) + return x diff --git a/extensions_built_in/diffusion_models/chroma/src/math.py b/extensions_built_in/diffusion_models/chroma/src/math.py new file mode 100644 index 00000000..b46bca57 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/math.py @@ -0,0 +1,33 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + # mask should have shape [B, H, L, D] + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/extensions_built_in/diffusion_models/chroma/src/model.py b/extensions_built_in/diffusion_models/chroma/src/model.py new file mode 100644 index 00000000..3b6c29bb --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/model.py @@ -0,0 +1,273 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +import torch.utils.checkpoint as ckpt + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + SingleStreamBlock, + timestep_embedding, + Approximator, + distribute_modulations, +) + + +@dataclass +class ChromaParams: + in_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + approximator_in_dim: int + approximator_depth: int + 
approximator_hidden_size: int + _use_compiled: bool + + +chroma_params = ChromaParams( + in_channels=64, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + approximator_in_dim=64, + approximator_depth=5, + approximator_hidden_size=5120, + _use_compiled=False, +) + + +def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): + """ + Modifies attention mask to allow attention to a few extra padding tokens. + + Args: + mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens) + max_seq_length: Maximum sequence length of the model + num_extra_padding: Number of padding tokens to unmask + + Returns: + Modified mask + """ + # Get the actual sequence length from the mask + seq_length = mask.sum(dim=-1) + batch_size = mask.shape[0] + + modified_mask = mask.clone() + + for i in range(batch_size): + current_seq_len = int(seq_length[i].item()) + + # Only add extra padding tokens if there's room + if current_seq_len < max_seq_length: + # Calculate how many padding tokens we can unmask + available_padding = max_seq_length - current_seq_len + tokens_to_unmask = min(num_extra_padding, available_padding) + + # Unmask the specified number of padding tokens right after the sequence + modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1 + + return modified_mask + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: ChromaParams): + super().__init__() + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + + # TODO: need proper mapping for this approximator output! 
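+        # The Approximator ("distilled_guidance_layer") is a small MLP that produces all of the
+        # modulation vectors in one pass, in place of per-block modulation layers. For each
+        # modulation slot its input is a 64-dim vector (16-dim timestep embedding + 16-dim
+        # guidance embedding + 32-dim embedding of the slot index, assembled in forward()),
+        # and it outputs one hidden_size-dim vector per slot for mod_index_length (344) slots.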
+ # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer( + self.hidden_size, + 1, + self.out_channels, + use_compiled=params._use_compiled, + ) + + # TODO: move this hardcoded value to config + self.mod_index_length = 344 + # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) + self.register_buffer( + "mod_index", + torch.tensor(list(range(self.mod_index_length)), device="cpu"), + persistent=False, + ) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + guidance: Tensor, + attn_padding: int = 1, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + txt = self.txt_in(txt) + + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + distill_timestep = timestep_embedding(timesteps, 16) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, 16) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, 32) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = ( + torch.cat([distill_timestep, distil_guidance], dim=1) + .unsqueeze(1) + .repeat(1, self.mod_index_length, 1) + ) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + mod_vectors_dict = distribute_modulations(mod_vectors) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + # compute mask + # assume max seq length from the batched input + + max_len = txt.shape[1] + + # mask + with torch.no_grad(): + txt_mask_w_padding = modify_mask_to_attend_padding( + txt_mask, max_len, attn_padding + ) + txt_img_mask = torch.cat( + [ + txt_mask_w_padding, + torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + ], + dim=1, + ) + txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() + txt_img_mask = ( + txt_img_mask[None, None, 
...] + .repeat(txt.shape[0], self.num_heads, 1, 1) + .int() + .bool() + ) + # txt_mask_w_padding[txt_mask_w_padding==False] = True + + for i, block in enumerate(self.double_blocks): + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + # just in case in different GPU for simple pipeline parallel + if self.training: + img.requires_grad_(True) + img, txt = ckpt.checkpoint( + block, img, txt, pe, double_mod, txt_img_mask + ) + else: + img, txt = block( + img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask + ) + + img = torch.cat((txt, img), 1) + for i, block in enumerate(self.single_blocks): + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + if self.training: + img.requires_grad_(True) + img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) + else: + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + img = img[:, txt.shape[1] :, ...] + final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + img = self.final_layer( + img, distill_vec=final_mod + ) # (N, T, patch_size ** 2 * out_channels) + return img
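
A minimal, self-contained sketch (not part of the patch) of the 2x2 latent packing that `ChromaModel.get_noise_prediction()` applies before calling the transformer, and the inverse rearrange applied to the model output. The batch size, 16 latent channels, and 128x128 latent grid below are illustrative assumptions.

```python
# Round-trip check of the packing used in ChromaModel.get_noise_prediction().
import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 128, 128          # assumed latent shape, for illustration only
latents = torch.randn(bs, c, h, w)

# pack 2x2 latent patches into tokens: (b, c, h, w) -> (b, h/2 * w/2, c * 4)
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
assert packed.shape == (bs, (h // 2) * (w // 2), c * 4)

# positional ids for the packed tokens, mirroring get_noise_prediction()
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
assert img_ids.shape == (bs, (h // 2) * (w // 2), 3)

# the transformer output is unpacked back to the latent grid with the inverse rearrange
unpacked = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=h // 2, w=w // 2, ph=2, pw=2, c=c)
assert torch.equal(latents, unpacked)
```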