From b95c17dc177002d1c1ffe71f95c280a21502e28c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 10 Sep 2025 08:41:05 -0600 Subject: [PATCH] Add initial support for chroma radiance --- .../diffusion_models/__init__.py | 3 +- .../diffusion_models/chroma/__init__.py | 3 +- .../diffusion_models/chroma/chroma_model.py | 19 +- .../chroma/chroma_radiance_model.py | 445 ++++++++++++++++++ .../diffusion_models/chroma/pipeline.py | 139 +++++- .../diffusion_models/chroma/src/layers.py | 226 ++++++++- .../diffusion_models/chroma/src/model.py | 10 +- .../diffusion_models/chroma/src/radiance.py | 380 +++++++++++++++ toolkit/models/FakeVAE.py | 134 ++++++ 9 files changed, 1339 insertions(+), 20 deletions(-) create mode 100644 extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py create mode 100644 extensions_built_in/diffusion_models/chroma/src/radiance.py create mode 100644 toolkit/models/FakeVAE.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 5682bc46..6449fe61 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -1,4 +1,4 @@ -from .chroma import ChromaModel +from .chroma import ChromaModel, ChromaRadianceModel from .hidream import HidreamModel, HidreamE1Model from .f_light import FLiteModel from .omnigen2 import OmniGen2Model @@ -9,6 +9,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel AI_TOOLKIT_MODELS = [ # put a list of models here ChromaModel, + ChromaRadianceModel, HidreamModel, HidreamE1Model, FLiteModel, diff --git a/extensions_built_in/diffusion_models/chroma/__init__.py b/extensions_built_in/diffusion_models/chroma/__init__.py index b20e2f40..c34866a6 100644 --- a/extensions_built_in/diffusion_models/chroma/__init__.py +++ b/extensions_built_in/diffusion_models/chroma/__init__.py @@ -1 +1,2 @@ -from .chroma_model import ChromaModel \ No newline at end of file +from .chroma_model import ChromaModel +from .chroma_radiance_model import ChromaRadianceModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py index 9bf3e51e..236d9508 100644 --- a/extensions_built_in/diffusion_models/chroma/chroma_model.py +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -15,7 +15,7 @@ from toolkit.accelerator import unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer -from .pipeline import ChromaPipeline +from .pipeline import ChromaPipeline, prepare_latent_image_ids from einops import rearrange, repeat import random import torch.nn.functional as F @@ -324,12 +324,19 @@ class ChromaModel(BaseModel): ph=2, pw=2 ) + + img_ids = prepare_latent_image_ids( + bs, + h, + w, + patch_size=2 + ).to(device=self.device_torch) - img_ids = torch.zeros(h // 2, w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", - b=bs).to(self.device_torch) + # img_ids = torch.zeros(h // 2, w // 2, 3) + # img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + # img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + # img_ids = repeat(img_ids, "h w c -> b (h w) c", + # b=bs).to(self.device_torch) txt_ids = torch.zeros( bs, 
text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
diff --git a/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py b/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py
new file mode 100644
index 00000000..333600e8
--- /dev/null
+++ b/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py
@@ -0,0 +1,445 @@
+import os
+from typing import TYPE_CHECKING
+
+import torch
+from toolkit.config_modules import GenerateImageConfig, ModelConfig
+from PIL import Image
+from toolkit.models.base_model import BaseModel
+from toolkit.basic import flush
+from diffusers import AutoencoderKL
+# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.dequantize import patch_dequantization_on_save
+from toolkit.accelerator import unwrap_model
+from optimum.quanto import freeze, QTensor
+from toolkit.util.quantize import quantize, get_qtype
+from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
+from .pipeline import ChromaPipeline, prepare_latent_image_ids
+from einops import rearrange, repeat
+import random
+import torch.nn.functional as F
+from .src.radiance import Chroma, chroma_params
+from safetensors.torch import load_file, save_file
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.FakeVAE import FakeVAE
+import huggingface_hub
+
+if TYPE_CHECKING:
+    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+
+scheduler_config = {
+    "base_image_seq_len": 256,
+    "base_shift": 0.5,
+    "max_image_seq_len": 4096,
+    "max_shift": 1.15,
+    "num_train_timesteps": 1000,
+    "shift": 3.0,
+    "use_dynamic_shifting": True
+}
+
+class FakeConfig:
+    # for diffusers compatibility
+    def __init__(self):
+        self.attention_head_dim = 128
+        self.guidance_embeds = True
+        self.in_channels = 64
+        self.joint_attention_dim = 4096
+        self.num_attention_heads = 24
+        self.num_layers = 19
+        self.num_single_layers = 38
+        self.patch_size = 1
+
+class FakeCLIP(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dtype = torch.bfloat16
+        self.device = 'cuda'
+        self.text_model = None
+        self.tokenizer = None
+        self.model_max_length = 77
+
+    def forward(self, *args, **kwargs):
+        return torch.zeros(1, 1, 1).to(self.device)
+
+
+class ChromaRadianceModel(BaseModel):
+    arch = "chroma_radiance"
+
+    def __init__(
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype='bf16',
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs
+    ):
+        super().__init__(
+            device,
+            model_config,
+            dtype,
+            custom_pipeline,
+            noise_scheduler,
+            **kwargs
+        )
+        self.is_flow_matching = True
+        self.is_transformer = True
+        self.target_lora_modules = ['Chroma']
+
+    # static method to get the noise scheduler
+    @staticmethod
+    def get_train_scheduler():
+        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
+
+    def get_bucket_divisibility(self):
+        # return the bucket divisibility for the model
+        return 32
+
+    def load_model(self):
+        dtype = self.torch_dtype
+
+        # will be updated if we detect an existing checkpoint in the training folder
+        model_path = self.model_config.name_or_path
+
+        if model_path == "lodestones/Chroma":
+            print("Looking for latest Chroma checkpoint")
+            # get the latest checkpoint
+            files_list = huggingface_hub.list_repo_files(model_path)
+            print(files_list)
+            latest_version = 28  # current latest version at time of writing
+            while True:
+                if
f"chroma-unlocked-v{latest_version}.safetensors" not in files_list: + latest_version -= 1 + break + else: + latest_version += 1 + print(f"Using latest Chroma version: v{latest_version}") + + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"chroma-unlocked-v{latest_version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma/v"): + # get the version number + version = model_path.split("/")[-1].split("v")[-1] + print(f"Using Chroma version: v{version}") + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id='lodestones/Chroma', + filename=f"chroma-unlocked-v{version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma1-"): + # will have a file in the repo that is Chroma1-whatever.safetensors + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"{model_path.split('/')[-1]}.safetensors", + ) + + else: + # check if the model path is a local file + if os.path.exists(model_path): + print(f"Using local model: {model_path}") + else: + raise ValueError(f"Model path {model_path} does not exist") + + # extras_path = 'black-forest-labs/FLUX.1-schnell' + # schnell model is gated now, use flex instead + extras_path = 'ostris/Flex.1-alpha' + + self.print_and_status_update("Loading transformer") + + if model_path.endswith('.pth') or model_path.endswith('.pt'): + chroma_state_dict = torch.load(model_path, map_location='cpu', weights_only=True) + else: + chroma_state_dict = load_file(model_path, 'cpu') + + # determine number of double and single blocks + double_blocks = 0 + single_blocks = 0 + for key in chroma_state_dict.keys(): + if "double_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > double_blocks: + double_blocks = block_num + elif "single_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > single_blocks: + single_blocks = block_num + print(f"Double Blocks: {double_blocks}") + print(f"Single Blocks: {single_blocks}") + + chroma_params.depth = double_blocks + chroma_params.depth_single_blocks = single_blocks + transformer = Chroma(chroma_params) + + # add dtype, not sure why it doesnt have it + transformer.dtype = dtype + # load the state dict into the model + transformer.load_state_dict(chroma_state_dict) + + transformer.to(self.quantize_device, dtype=dtype) + + transformer.config = FakeConfig() + transformer.config.num_layers = double_blocks + transformer.config.num_single_layers = single_blocks + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + # self.print_and_status_update("Loading CLIP") + text_encoder 
= FakeCLIP() + tokenizer = FakeCLIP() + text_encoder.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ChromaRadianceModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + # vae = AutoencoderKL.from_pretrained( + # extras_path, + # subfolder="vae", + # torch_dtype=dtype + # ) + vae = FakeVAE() + vae = vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: ChromaPipeline = ChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + is_radiance=True, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ChromaRadianceModel.get_train_scheduler() + pipeline = ChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + is_radiance=True, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attn_mask=conditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + + img_ids = prepare_latent_image_ids( + bs, h, w, patch_size=16 + ).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32) + guidance = guidance.expand(bs) + + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + img=latent_model_input.to( + self.device_torch, cast_dtype + ), + img_ids=img_ids, + 
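+            # Note (explanatory, not in the original commit): img_ids above
+            # carries one RoPE position per 16x16 pixel patch (see
+            # prepare_latent_image_ids); timesteps below are rescaled from the
+            # trainer's 0-1000 range to the 0-1 range the flow matching model
+            # expects, and guidance is pinned to 0 for this guidance-distilled
+            # checkpoint.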
+            txt=text_embeddings.text_embeds.to(
+                self.device_torch, cast_dtype
+            ),
+            txt_ids=txt_ids,
+            txt_mask=text_embeddings.attention_mask.to(
+                self.device_torch, cast_dtype
+            ),
+            timesteps=timestep / 1000,
+            guidance=guidance
+        )
+
+        if isinstance(noise_pred, QTensor):
+            noise_pred = noise_pred.dequantize()
+
+        return noise_pred
+
+    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+        if isinstance(prompt, str):
+            prompts = [prompt]
+        else:
+            prompts = prompt
+        if self.pipeline.text_encoder.device != self.device_torch:
+            self.pipeline.text_encoder.to(self.device_torch)
+
+        max_length = 512
+
+        device = self.text_encoder[1].device
+        dtype = self.text_encoder[1].dtype
+
+        # T5
+        text_inputs = self.tokenizer[1](
+            prompts,
+            padding="max_length",
+            max_length=max_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+
+        dtype = self.text_encoder[1].dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        prompt_attention_mask = text_inputs["attention_mask"]
+
+        pe = PromptEmbeds(
+            prompt_embeds
+        )
+        pe.attention_mask = prompt_attention_mask
+        return pe
+
+    def get_model_has_grad(self):
+        # return from a weight if it has grad
+        return False
+
+    def get_te_has_grad(self):
+        # return from a weight if it has grad
+        return False
+
+    def save_model(self, output_path, meta, save_dtype):
+        if not output_path.endswith(".safetensors"):
+            output_path = output_path + ".safetensors"
+        # only save the transformer
+        transformer: Chroma = unwrap_model(self.model)
+        state_dict = transformer.state_dict()
+        save_dict = {}
+        for k, v in state_dict.items():
+            if isinstance(v, QTensor):
+                v = v.dequantize()
+            save_dict[k] = v.clone().to('cpu', dtype=save_dtype)
+
+        meta = get_meta_for_safetensors(meta, name='chroma')
+        save_file(save_dict, output_path, metadata=meta)
+
+    def get_loss_target(self, *args, **kwargs):
+        noise = kwargs.get('noise')
+        batch = kwargs.get('batch')
+        return (noise - batch.latents).detach()
+
+    def convert_lora_weights_before_save(self, state_dict):
+        # currently starts with transformer. but needs to start with diffusion_model. for comfyui
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("transformer.", "diffusion_model.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def convert_lora_weights_before_load(self, state_dict):
+        # saved as diffusion_model. but needs to be transformer.
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "chroma_radiance" diff --git a/extensions_built_in/diffusion_models/chroma/pipeline.py b/extensions_built_in/diffusion_models/chroma/pipeline.py index 215be798..5a76a713 100644 --- a/extensions_built_in/diffusion_models/chroma/pipeline.py +++ b/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -6,6 +6,7 @@ from diffusers import FluxPipeline from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput from diffusers.utils import is_torch_xla_available +from diffusers.utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -16,7 +17,134 @@ else: XLA_AVAILABLE = False +def prepare_latent_image_ids(batch_size, height, width, patch_size=2, max_offset=0): + """ + Generates positional embeddings for a latent image. + + Args: + batch_size (int): The number of images in the batch. + height (int): The height of the image. + width (int): The width of the image. + patch_size (int, optional): The size of the patches. Defaults to 2. + max_offset (int, optional): The maximum random offset to apply. Defaults to 0. + + Returns: + torch.Tensor: A tensor containing the positional embeddings. + """ + # the random pos embedding helps generalize to larger res without training at large res + # pos embedding for rope, 2d pos embedding, corner embedding and not center based + latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3) + + # Add positional encodings + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :] + ) + + # Add random offset if specified + if max_offset > 0: + offset_y = torch.randint(0, max_offset + 1, (1,)).item() + offset_x = torch.randint(0, max_offset + 1, (1,)).item() + latent_image_ids[..., 1] += offset_y + latent_image_ids[..., 2] += offset_x + + + ( + latent_image_id_height, + latent_image_id_width, + latent_image_id_channels, + ) = latent_image_ids.shape + + # Reshape for batch + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, + latent_image_id_height * latent_image_id_width, + latent_image_id_channels, + ) + + return latent_image_ids + + class ChromaPipeline(FluxPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + image_encoder = None, + feature_extractor = None, + is_radiance: bool = False, + ): + super().__init__( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.is_radiance = is_radiance + self.vae_scale_factor = 8 if not is_radiance else 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
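+        # For radiance (is_radiance=True) the FakeVAE is an identity mapping:
+        # vae_scale_factor is 1, so these "latents" stay at pixel resolution
+        # with 3 channels, and positional ids are laid out on a 16x16-pixel
+        # patch grid instead of the usual 2x2 grid over 8x-compressed latents.
+        # E.g. a 1024x1024 image keeps shape [B, 3, 1024, 1024] and gets a
+        # 64x64 = 4096-entry id grid.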
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if not self.is_radiance: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + + return latents, latent_image_ids + def __call__( self, prompt: Union[str, List[str]] = None, @@ -70,6 +198,8 @@ class ChromaPipeline(FluxPipeline): # 4. Prepare latent variables num_channels_latents = 64 // 4 + if self.is_radiance: + num_channels_latents = 3 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -82,8 +212,8 @@ class ChromaPipeline(FluxPipeline): ) # extend img ids to match batch size - latent_image_ids = latent_image_ids.unsqueeze(0) - latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) + # latent_image_ids = latent_image_ids.unsqueeze(0) + # latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) @@ -180,8 +310,9 @@ class ChromaPipeline(FluxPipeline): image = latents else: - latents = self._unpack_latents( - latents, height, width, self.vae_scale_factor) + if not self.is_radiance: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + \ self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] diff --git a/extensions_built_in/diffusion_models/chroma/src/layers.py b/extensions_built_in/diffusion_models/chroma/src/layers.py index 726ec6a0..030a57cd 100644 --- a/extensions_built_in/diffusion_models/chroma/src/layers.py +++ b/extensions_built_in/diffusion_models/chroma/src/layers.py @@ -7,6 +7,7 @@ from torch import Tensor, nn import torch.nn.functional as F from .math import attention, rope +from functools import lru_cache class EmbedND(nn.Module): @@ -88,7 +89,7 @@ class RMSNorm(torch.nn.Module): # return self._forward(x) -def distribute_modulations(tensor: torch.Tensor): +def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks): """ Distributes slices of the tensor into the block_dict as ModulationOut objects. @@ -102,25 +103,25 @@ def distribute_modulations(tensor: torch.Tensor): # HARD CODED VALUES! lookup table for the generated vectors # TODO: move this into chroma config! 
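+    # For the stock config (38 single blocks, 19 double blocks) this lookup
+    # table holds 3*38 + 6*19 + 6*19 + 2 = 344 modulation vectors, which is
+    # exactly mod_index_length = 3 * depth_single_blocks + 2 * 6 * depth + 2
+    # as computed in model.py / radiance.py.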
# Add 38 single mod blocks - for i in range(38): + for i in range(depth_single_blocks): key = f"single_blocks.{i}.modulation.lin" block_dict[key] = None # Add 19 image double blocks - for i in range(19): + for i in range(depth_double_blocks): key = f"double_blocks.{i}.img_mod.lin" block_dict[key] = None # Add 19 text double blocks - for i in range(19): + for i in range(depth_double_blocks): key = f"double_blocks.{i}.txt_mod.lin" block_dict[key] = None # Add the final layer block_dict["final_layer.adaLN_modulation.1"] = None # 6.2b version - block_dict["lite_double_blocks.4.img_mod.lin"] = None - block_dict["lite_double_blocks.4.txt_mod.lin"] = None + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None idx = 0 # Index to keep track of the vector slices @@ -173,6 +174,219 @@ def distribute_modulations(tensor: torch.Tensor): return block_dict + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__(self, in_channels, hidden_size_input, max_freqs): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + nn.Linear(in_channels + max_freqs**2, hidden_size_input) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size, device, dtype): + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. 
+ # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs): + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + # Store the original dtype to cast back to at the end. + original_dtype = inputs.dtype + # Force all operations within this module to run in fp32. + with torch.autocast("cuda", enabled=False): + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + inputs = inputs.float() + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, torch.float32) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat([inputs, dct], dim=-1) + + # Project the combined tensor to the target hidden size. + inputs = self.embedder.float()(inputs) + + return inputs.to(original_dtype) + + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, use_compiled): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = nn.Linear(hidden_size_s, total_params) + self.norm = RMSNorm(hidden_size_x, use_compiled) + self.mlp_ratio = mlp_ratio + # nn.init.zeros_(self.param_generator.weight) + # nn.init.zeros_(self.param_generator.bias) + + + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. 
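+        # The generated weights implement a GLU-style MLP:
+        #   x_out = (SiLU(x @ fc1_gate) * (x @ fc1_value)) @ fc2
+        # with fc1_gate/fc1_value: [hidden_x, hidden_x * mlp_ratio] and
+        # fc2: [hidden_x * mlp_ratio, hidden_x], all sliced per-sample from
+        # param_generator(s).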
+ x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + x = x + res_x + return x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, use_compiled): + super().__init__() + self.norm = RMSNorm(hidden_size, use_compiled=use_compiled) + self.linear = nn.Linear(hidden_size, out_channels) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x): + x = self.norm(x) + x = self.linear(x) + return x + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size, out_channels, use_compiled): + super().__init__() + self.norm = RMSNorm(hidden_size, use_compiled=use_compiled) + + # replace nn.Linear with nn.Conv2d since linear is just pointwise conv + self.conv = nn.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1 + ) + nn.init.zeros_(self.conv.weight) + nn.init.zeros_(self.conv.bias) + + def forward(self, x): + # shape: [N, C, H, W] ! + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So, we permute the dimensions to make the channel dimension the last one. + x_permuted = x.permute(0, 2, 3, 1) # Shape becomes [N, H, W, C] + + # Apply normalization on the feature/channel dimension + x_norm = self.norm(x_permuted) + + # Permute back to the original dimension order for the convolution + x_norm_permuted = x_norm.permute(0, 3, 1, 2) # Shape becomes [N, C, H, W] + + # Apply the 3x3 convolution + x = self.conv(x_norm_permuted) + return x + + class Approximator(nn.Module): def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4): super().__init__() diff --git a/extensions_built_in/diffusion_models/chroma/src/model.py b/extensions_built_in/diffusion_models/chroma/src/model.py index 33cdbe62..ebdf69d9 100644 --- a/extensions_built_in/diffusion_models/chroma/src/model.py +++ b/extensions_built_in/diffusion_models/chroma/src/model.py @@ -156,13 +156,19 @@ class Chroma(nn.Module): ) # TODO: move this hardcoded value to config - self.mod_index_length = 344 + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) self.register_buffer( "mod_index", torch.tensor(list(range(self.mod_index_length)), device="cpu"), persistent=False, ) + self.approximator_in_dim = params.approximator_in_dim @property def device(self): @@ -213,7 +219,7 @@ class Chroma(nn.Module): # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) - mod_vectors_dict = distribute_modulations(mod_vectors) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) diff --git a/extensions_built_in/diffusion_models/chroma/src/radiance.py b/extensions_built_in/diffusion_models/chroma/src/radiance.py new file mode 100644 index 00000000..d328f261 --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/radiance.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +import 
torch.utils.checkpoint as ckpt + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + SingleStreamBlock, + timestep_embedding, + Approximator, + distribute_modulations, + NerfEmbedder, + NerfFinalLayer, + NerfFinalLayerConv, + NerfGLUBlock +) + + +@dataclass +class ChromaParams: + in_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + approximator_in_dim: int + approximator_depth: int + approximator_hidden_size: int + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + _use_compiled: bool + + +chroma_params = ChromaParams( + in_channels=3, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + approximator_in_dim=64, + approximator_depth=5, + approximator_hidden_size=5120, + patch_size=16, + nerf_hidden_size=64, + nerf_mlp_ratio=4, + nerf_depth=4, + nerf_max_freqs=8, + _use_compiled=False, +) + + +def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): + """ + Modifies attention mask to allow attention to a few extra padding tokens. + + Args: + mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens) + max_seq_length: Maximum sequence length of the model + num_extra_padding: Number of padding tokens to unmask + + Returns: + Modified mask + """ + # Get the actual sequence length from the mask + seq_length = mask.sum(dim=-1) + batch_size = mask.shape[0] + + modified_mask = mask.clone() + + for i in range(batch_size): + current_seq_len = int(seq_length[i].item()) + + # Only add extra padding tokens if there's room + if current_seq_len < max_seq_length: + # Calculate how many padding tokens we can unmask + available_padding = max_seq_length - current_seq_len + tokens_to_unmask = min(num_extra_padding, available_padding) + + # Unmask the specified number of padding tokens right after the sequence + modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1 + + return modified_mask + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: ChromaParams): + super().__init__() + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + self.gradient_checkpointing = False + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + # self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + # patchify ops + self.img_in_patch = nn.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True + ) + nn.init.zeros_(self.img_in_patch.weight) + nn.init.zeros_(self.img_in_patch.bias) + # TODO: need proper mapping for this approximator output! 
+        # currently the mapping is hardcoded in distribute_modulations function
+        self.distilled_guidance_layer = Approximator(
+            params.approximator_in_dim,
+            self.hidden_size,
+            params.approximator_hidden_size,
+            params.approximator_depth,
+        )
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                    use_compiled=params._use_compiled,
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    use_compiled=params._use_compiled,
+                )
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        # self.final_layer = LastLayer(
+        #     self.hidden_size,
+        #     1,
+        #     self.out_channels,
+        #     use_compiled=params._use_compiled,
+        # )
+
+        # pixel channel concat with DCT
+        self.nerf_image_embedder = NerfEmbedder(
+            in_channels=params.in_channels,
+            hidden_size_input=params.nerf_hidden_size,
+            max_freqs=params.nerf_max_freqs
+        )
+
+        self.nerf_blocks = nn.ModuleList([
+            NerfGLUBlock(
+                hidden_size_s=params.hidden_size,
+                hidden_size_x=params.nerf_hidden_size,
+                mlp_ratio=params.nerf_mlp_ratio,
+                use_compiled=params._use_compiled
+            ) for _ in range(params.nerf_depth)
+        ])
+        # self.nerf_final_layer = NerfFinalLayer(
+        #     params.nerf_hidden_size,
+        #     out_channels=params.in_channels,
+        #     use_compiled=params._use_compiled
+        # )
+        self.nerf_final_layer_conv = NerfFinalLayerConv(
+            params.nerf_hidden_size,
+            out_channels=params.in_channels,
+            use_compiled=params._use_compiled
+        )
+        # TODO: move this hardcoded value to config
+        # single layer has 3 modulation vectors
+        # double layer has 6 modulation vectors for each expert
+        # final layer has 2 modulation vectors
+        self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
+        self.depth_single_blocks = params.depth_single_blocks
+        self.depth_double_blocks = params.depth
+        # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
+        self.register_buffer(
+            "mod_index",
+            torch.tensor(list(range(self.mod_index_length)), device="cpu"),
+            persistent=False,
+        )
+        self.approximator_in_dim = params.approximator_in_dim
+
+    @property
+    def device(self):
+        # Get the device of the module (assumes all parameters are on the same device)
+        return next(self.parameters()).device
+
+    def enable_gradient_checkpointing(self, enable: bool = True):
+        self.gradient_checkpointing = enable
+
+    def forward(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        txt_mask: Tensor,
+        timesteps: Tensor,
+        guidance: Tensor,
+        attn_padding: int = 1,
+    ) -> Tensor:
+        if img.ndim != 4:
+            raise ValueError("Input img tensor must be in [B, C, H, W] format.")
+        if txt.ndim != 3:
+            raise ValueError("Input txt tensors must have 3 dimensions.")
+        B, C, H, W = img.shape
+
+        # gemini gogogo idk how to unfold and pack the patch properly :P
+        # Store the raw pixel values of each patch for the NeRF head later.
+        # unfold creates patches: [B, C * P * P, NumPatches]
+        nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size)
+        nerf_pixels = nerf_pixels.transpose(1, 2)  # -> [B, NumPatches, C * P * P]
+
+        # patchify ops
+        img = self.img_in_patch(img)  # -> [B, Hidden, H/P, W/P]
+        num_patches = img.shape[2] * img.shape[3]
+        # flatten into a sequence for the transformer.
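+        # Shape walk-through for e.g. B=1, C=3, H=W=512, P=16:
+        #   nerf_pixels: [1, 1024, 768]    (32*32 = 1024 patches of 3*16*16 pixels)
+        #   img:         [1, 3072, 32, 32] -> [1, 1024, 3072] after flattening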
+ img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + txt = self.txt_in(txt) + + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim//4) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, self.approximator_in_dim//4) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim//2) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = ( + torch.cat([distill_timestep, distil_guidance], dim=1) + .unsqueeze(1) + .repeat(1, self.mod_index_length, 1) + ) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + # compute mask + # assume max seq length from the batched input + + max_len = txt.shape[1] + + # mask + with torch.no_grad(): + txt_mask_w_padding = modify_mask_to_attend_padding( + txt_mask, max_len, attn_padding + ) + txt_img_mask = torch.cat( + [ + txt_mask_w_padding, + torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + ], + dim=1, + ) + txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() + txt_img_mask = ( + txt_img_mask[None, None, ...] + .repeat(txt.shape[0], self.num_heads, 1, 1) + .int() + .bool() + ) + # txt_mask_w_padding[txt_mask_w_padding==False] = True + + for i, block in enumerate(self.double_blocks): + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + # just in case in different GPU for simple pipeline parallel + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img, txt = ckpt.checkpoint( + block, img, txt, pe, double_mod, txt_img_mask + ) + else: + img, txt = block( + img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask + ) + + img = torch.cat((txt, img), 1) + for i, block in enumerate(self.single_blocks): + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) + else: + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + img = img[:, txt.shape[1] :, ...] 
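+
+        # img now holds only the image tokens again; each 3072-dim token goes
+        # on to condition the per-patch NeRF head that re-synthesizes its own
+        # 16x16 pixel patch below.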
+ + # final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + # img = self.final_layer( + # img, distill_vec=final_mod + # ) # (N, T, patch_size ** 2 * out_channels) + + # aliasing + nerf_hidden = img + # reshape for per-patch processing + nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # pass through the dynamic MLP blocks (the NeRF) + for i, block in enumerate(self.nerf_blocks): + if self.training: + img_dct = ckpt.checkpoint(block, img_dct, nerf_hidden) + else: + img_dct = block(img_dct, nerf_hidden) + + # final projection to get the output pixel values + # img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C] + img_dct = self.nerf_final_layer_conv.norm(img_dct) + + # gemini gogogo idk how to fold this properly :P + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=self.params.patch_size, + stride=self.params.patch_size + ) # [B, Hidden, H, W] + img_dct = self.nerf_final_layer_conv.conv(img_dct) + + return img_dct \ No newline at end of file diff --git a/toolkit/models/FakeVAE.py b/toolkit/models/FakeVAE.py new file mode 100644 index 00000000..86ca4730 --- /dev/null +++ b/toolkit/models/FakeVAE.py @@ -0,0 +1,134 @@ +from diffusers import AutoencoderKL +from typing import Optional, Union +import torch +import torch.nn as nn +import numpy as np +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput +from diffusers.models.autoencoders.vae import DecoderOutput + + +class Config: + in_channels = 3 + out_channels = 3 + down_block_types = ("1",) + up_block_types = ("1",) + block_out_channels = (1,) + latent_channels = 3 # usually 4 + norm_num_groups = 1 + sample_size = 512 + scaling_factor = 1.0 + # scaling_factor = 1.8 + shift_factor = 0 + + def __getitem__(cls, x): + return getattr(cls, x) + + +class FakeVAE(nn.Module): + def __init__(self): + super().__init__() + self._dtype = torch.float32 + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.config = Config() + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, value): + self._dtype = value + + @property + def device(self): + return self._device + + @device.setter + def device(self, value): + self._device = value + + # mimic to from torch + def to(self, *args, **kwargs): + # pull out dtype and device if they exist + if "dtype" in kwargs: + self._dtype = kwargs["dtype"] + if "device" in kwargs: + self._device = kwargs["device"] + return super().to(*args, **kwargs) + + def enable_xformers_memory_efficient_attention(self): + pass + + # @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> AutoencoderKLOutput: + h = x + + # moments = self.quant_conv(h) + # posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (h,) + + class FakeDist: + def __init__(self, x): + self._sample = x + + def sample(self): + return self._sample + + return AutoencoderKLOutput(latent_dist=FakeDist(h)) + + def _decode( + self, z: 
torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + dec = z + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def _set_gradient_checkpointing(self, module, value=False): + pass + + def enable_tiling(self, use_tiling: bool = True): + pass + + def disable_tiling(self): + pass + + def enable_slicing(self): + pass + + def disable_slicing(self): + pass + + def set_use_memory_efficient_attention_xformers(self, value: bool = True): + pass + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + dec = sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec)
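+
+
+# Usage sketch (illustrative): FakeVAE stands in for a real VAE so the radiance
+# model can train and sample directly in pixel space. encode() and decode() are
+# identity mappings, and scaling_factor=1.0 / shift_factor=0 keep the
+# pipeline's latent (un)scaling a no-op:
+#
+#   vae = FakeVAE()
+#   x = torch.randn(1, 3, 64, 64)
+#   z = vae.encode(x).latent_dist.sample()   # z is x, unchanged
+#   out = vae.decode(z).sample               # out is z, unchanged
+#   assert torch.equal(out, x)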