Add initial support for chroma radiance

This commit is contained in:
Jaret Burkett
2025-09-10 08:41:05 -06:00
parent af6fdaaaf9
commit b95c17dc17
9 changed files with 1339 additions and 20 deletions

View File

@@ -1,4 +1,4 @@
from .chroma import ChromaModel
from .chroma import ChromaModel, ChromaRadianceModel
from .hidream import HidreamModel, HidreamE1Model
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
@@ -9,6 +9,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel
AI_TOOLKIT_MODELS = [
# put a list of models here
ChromaModel,
ChromaRadianceModel,
HidreamModel,
HidreamE1Model,
FLiteModel,

View File

@@ -1 +1,2 @@
from .chroma_model import ChromaModel
from .chroma_model import ChromaModel
from .chroma_radiance_model import ChromaRadianceModel

View File

@@ -15,7 +15,7 @@ from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from .pipeline import ChromaPipeline
from .pipeline import ChromaPipeline, prepare_latent_image_ids
from einops import rearrange, repeat
import random
import torch.nn.functional as F
@@ -324,12 +324,19 @@ class ChromaModel(BaseModel):
ph=2,
pw=2
)
img_ids = prepare_latent_image_ids(
bs,
h,
w,
patch_size=2
).to(device=self.device_torch)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c",
b=bs).to(self.device_torch)
# img_ids = torch.zeros(h // 2, w // 2, 3)
# img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
# img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
# img_ids = repeat(img_ids, "h w c -> b (h w) c",
# b=bs).to(self.device_torch)
txt_ids = torch.zeros(
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

View File

@@ -0,0 +1,445 @@
import os
from typing import TYPE_CHECKING
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from diffusers import AutoencoderKL
# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from .pipeline import ChromaPipeline, prepare_latent_image_ids
from einops import rearrange, repeat
import random
import torch.nn.functional as F
from .src.radiance import Chroma, chroma_params
from safetensors.torch import load_file, save_file
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.FakeVAE import FakeVAE
import huggingface_hub
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True
}
class FakeConfig:
# for diffusers compatability
def __init__(self):
self.attention_head_dim = 128
self.guidance_embeds = True
self.in_channels = 64
self.joint_attention_dim = 4096
self.num_attention_heads = 24
self.num_layers = 19
self.num_single_layers = 38
self.patch_size = 1
class FakeCLIP(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.bfloat16
self.device = 'cuda'
self.text_model = None
self.tokenizer = None
self.model_max_length = 77
def forward(self, *args, **kwargs):
return torch.zeros(1, 1, 1).to(self.device)
class ChromaRadianceModel(BaseModel):
arch = "chroma_radiance"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['Chroma']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
# return the bucket divisibility for the model
return 32
def load_model(self):
dtype = self.torch_dtype
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
if model_path == "lodestones/Chroma":
print("Looking for latest Chroma checkpoint")
# get the latest checkpoint
files_list = huggingface_hub.list_repo_files(model_path)
print(files_list)
latest_version = 28 # current latest version at time of writing
while True:
if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list:
latest_version -= 1
break
else:
latest_version += 1
print(f"Using latest Chroma version: v{latest_version}")
# make sure we have it
model_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=f"chroma-unlocked-v{latest_version}.safetensors",
)
elif model_path.startswith("lodestones/Chroma/v"):
# get the version number
version = model_path.split("/")[-1].split("v")[-1]
print(f"Using Chroma version: v{version}")
# make sure we have it
model_path = huggingface_hub.hf_hub_download(
repo_id='lodestones/Chroma',
filename=f"chroma-unlocked-v{version}.safetensors",
)
elif model_path.startswith("lodestones/Chroma1-"):
# will have a file in the repo that is Chroma1-whatever.safetensors
model_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=f"{model_path.split('/')[-1]}.safetensors",
)
else:
# check if the model path is a local file
if os.path.exists(model_path):
print(f"Using local model: {model_path}")
else:
raise ValueError(f"Model path {model_path} does not exist")
# extras_path = 'black-forest-labs/FLUX.1-schnell'
# schnell model is gated now, use flex instead
extras_path = 'ostris/Flex.1-alpha'
self.print_and_status_update("Loading transformer")
if model_path.endswith('.pth') or model_path.endswith('.pt'):
chroma_state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
else:
chroma_state_dict = load_file(model_path, 'cpu')
# determine number of double and single blocks
double_blocks = 0
single_blocks = 0
for key in chroma_state_dict.keys():
if "double_blocks" in key:
block_num = int(key.split(".")[1]) + 1
if block_num > double_blocks:
double_blocks = block_num
elif "single_blocks" in key:
block_num = int(key.split(".")[1]) + 1
if block_num > single_blocks:
single_blocks = block_num
print(f"Double Blocks: {double_blocks}")
print(f"Single Blocks: {single_blocks}")
chroma_params.depth = double_blocks
chroma_params.depth_single_blocks = single_blocks
transformer = Chroma(chroma_params)
# add dtype, not sure why it doesnt have it
transformer.dtype = dtype
# load the state dict into the model
transformer.load_state_dict(chroma_state_dict)
transformer.to(self.quantize_device, dtype=dtype)
transformer.config = FakeConfig()
transformer.config.num_layers = double_blocks
transformer.config.num_single_layers = single_blocks
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading T5")
tokenizer_2 = T5TokenizerFast.from_pretrained(
extras_path, subfolder="tokenizer_2", torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
extras_path, subfolder="text_encoder_2", torch_dtype=dtype
)
text_encoder_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing T5")
quantize(text_encoder_2, weights=get_qtype(
self.model_config.qtype))
freeze(text_encoder_2)
flush()
# self.print_and_status_update("Loading CLIP")
text_encoder = FakeCLIP()
tokenizer = FakeCLIP()
text_encoder.to(self.device_torch, dtype=dtype)
self.noise_scheduler = ChromaRadianceModel.get_train_scheduler()
self.print_and_status_update("Loading VAE")
# vae = AutoencoderKL.from_pretrained(
# extras_path,
# subfolder="vae",
# torch_dtype=dtype
# )
vae = FakeVAE()
vae = vae.to(self.device_torch, dtype=dtype)
self.print_and_status_update("Making pipe")
pipe: ChromaPipeline = ChromaPipeline(
scheduler=self.noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=None,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=None,
is_radiance=True,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
text_encoder[1].to(self.device_torch)
text_encoder[1].requires_grad_(False)
text_encoder[1].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = ChromaRadianceModel.get_train_scheduler()
pipeline = ChromaPipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
text_encoder_2=unwrap_model(self.text_encoder[1]),
tokenizer_2=self.tokenizer[1],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
is_radiance=True,
)
# pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: ChromaPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
prompt_attn_mask=conditional_embeds.attention_mask,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
with torch.no_grad():
bs, c, h, w = latent_model_input.shape
img_ids = prepare_latent_image_ids(
bs, h, w, patch_size=16
).to(self.device_torch)
txt_ids = torch.zeros(
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32)
guidance = guidance.expand(bs)
cast_dtype = self.unet.dtype
noise_pred = self.unet(
img=latent_model_input.to(
self.device_torch, cast_dtype
),
img_ids=img_ids,
txt=text_embeddings.text_embeds.to(
self.device_torch, cast_dtype
),
txt_ids=txt_ids,
txt_mask=text_embeddings.attention_mask.to(
self.device_torch, cast_dtype
),
timesteps=timestep / 1000,
guidance=guidance
)
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if isinstance(prompt, str):
prompts = [prompt]
else:
prompts = prompt
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
max_length = 512
device = self.text_encoder[1].device
dtype = self.text_encoder[1].dtype
# T5
text_inputs = self.tokenizer[1](
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder[1].dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_attention_mask = text_inputs["attention_mask"]
pe = PromptEmbeds(
prompt_embeds
)
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return False
def get_te_has_grad(self):
# return from a weight if it has grad
return False
def save_model(self, output_path, meta, save_dtype):
if not output_path.endswith(".safetensors"):
output_path = output_path + ".safetensors"
# only save the unet
transformer: Chroma = unwrap_model(self.model)
state_dict = transformer.state_dict()
save_dict = {}
for k, v in state_dict.items():
if isinstance(v, QTensor):
v = v.dequantize()
save_dict[k] = v.clone().to('cpu', dtype=save_dtype)
meta = get_meta_for_safetensors(meta, name='chroma')
save_file(save_dict, output_path, metadata=meta)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
return (noise - batch.latents).detach()
def convert_lora_weights_before_save(self, state_dict):
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
# saved as diffusion_model. but needs to be transformer. for ai-toolkit
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "chroma_radiance"

View File

@@ -6,6 +6,7 @@ from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import is_torch_xla_available
from diffusers.utils.torch_utils import randn_tensor
if is_torch_xla_available():
@@ -16,7 +17,134 @@ else:
XLA_AVAILABLE = False
def prepare_latent_image_ids(batch_size, height, width, patch_size=2, max_offset=0):
"""
Generates positional embeddings for a latent image.
Args:
batch_size (int): The number of images in the batch.
height (int): The height of the image.
width (int): The width of the image.
patch_size (int, optional): The size of the patches. Defaults to 2.
max_offset (int, optional): The maximum random offset to apply. Defaults to 0.
Returns:
torch.Tensor: A tensor containing the positional embeddings.
"""
# the random pos embedding helps generalize to larger res without training at large res
# pos embedding for rope, 2d pos embedding, corner embedding and not center based
latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3)
# Add positional encodings
latent_image_ids[..., 1] = (
latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None]
)
latent_image_ids[..., 2] = (
latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :]
)
# Add random offset if specified
if max_offset > 0:
offset_y = torch.randint(0, max_offset + 1, (1,)).item()
offset_x = torch.randint(0, max_offset + 1, (1,)).item()
latent_image_ids[..., 1] += offset_y
latent_image_ids[..., 2] += offset_x
(
latent_image_id_height,
latent_image_id_width,
latent_image_id_channels,
) = latent_image_ids.shape
# Reshape for batch
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size,
latent_image_id_height * latent_image_id_width,
latent_image_id_channels,
)
return latent_image_ids
class ChromaPipeline(FluxPipeline):
def __init__(
self,
scheduler,
vae,
text_encoder,
tokenizer,
text_encoder_2,
tokenizer_2,
transformer,
image_encoder = None,
feature_extractor = None,
is_radiance: bool = False,
):
super().__init__(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.is_radiance = is_radiance
self.vae_scale_factor = 8 if not is_radiance else 1
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = prepare_latent_image_ids(
batch_size,
height,
width,
patch_size=2 if not self.is_radiance else 16
).to(device=device, dtype=dtype)
# latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if not self.is_radiance:
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
# latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
latent_image_ids = prepare_latent_image_ids(
batch_size,
height,
width,
patch_size=2 if not self.is_radiance else 16
).to(device=device, dtype=dtype)
return latents, latent_image_ids
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -70,6 +198,8 @@ class ChromaPipeline(FluxPipeline):
# 4. Prepare latent variables
num_channels_latents = 64 // 4
if self.is_radiance:
num_channels_latents = 3
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -82,8 +212,8 @@ class ChromaPipeline(FluxPipeline):
)
# extend img ids to match batch size
latent_image_ids = latent_image_ids.unsqueeze(0)
latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0)
# latent_image_ids = latent_image_ids.unsqueeze(0)
# latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
@@ -180,8 +310,9 @@ class ChromaPipeline(FluxPipeline):
image = latents
else:
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor)
if not self.is_radiance:
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + \
self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -7,6 +7,7 @@ from torch import Tensor, nn
import torch.nn.functional as F
from .math import attention, rope
from functools import lru_cache
class EmbedND(nn.Module):
@@ -88,7 +89,7 @@ class RMSNorm(torch.nn.Module):
# return self._forward(x)
def distribute_modulations(tensor: torch.Tensor):
def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
"""
Distributes slices of the tensor into the block_dict as ModulationOut objects.
@@ -102,25 +103,25 @@ def distribute_modulations(tensor: torch.Tensor):
# HARD CODED VALUES! lookup table for the generated vectors
# TODO: move this into chroma config!
# Add 38 single mod blocks
for i in range(38):
for i in range(depth_single_blocks):
key = f"single_blocks.{i}.modulation.lin"
block_dict[key] = None
# Add 19 image double blocks
for i in range(19):
for i in range(depth_double_blocks):
key = f"double_blocks.{i}.img_mod.lin"
block_dict[key] = None
# Add 19 text double blocks
for i in range(19):
for i in range(depth_double_blocks):
key = f"double_blocks.{i}.txt_mod.lin"
block_dict[key] = None
# Add the final layer
block_dict["final_layer.adaLN_modulation.1"] = None
# 6.2b version
block_dict["lite_double_blocks.4.img_mod.lin"] = None
block_dict["lite_double_blocks.4.txt_mod.lin"] = None
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
idx = 0 # Index to keep track of the vector slices
@@ -173,6 +174,219 @@ def distribute_modulations(tensor: torch.Tensor):
return block_dict
class NerfEmbedder(nn.Module):
"""
An embedder module that combines input features with a 2D positional
encoding that mimics the Discrete Cosine Transform (DCT).
This module takes an input tensor of shape (B, P^2, C), where P is the
patch size, and enriches it with positional information before projecting
it to a new hidden size.
"""
def __init__(self, in_channels, hidden_size_input, max_freqs):
"""
Initializes the NerfEmbedder.
Args:
in_channels (int): The number of channels in the input tensor.
hidden_size_input (int): The desired dimension of the output embedding.
max_freqs (int): The number of frequency components to use for both
the x and y dimensions of the positional encoding.
The total number of positional features will be max_freqs^2.
"""
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
self.embedder = nn.Sequential(
nn.Linear(in_channels + max_freqs**2, hidden_size_input)
)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size, device, dtype):
"""
Generates and caches 2D DCT-like positional embeddings for a given patch size.
The LRU cache is a performance optimization that avoids recomputing the
same positional grid on every forward pass.
Args:
patch_size (int): The side length of the square input patch.
device: The torch device to create the tensors on.
dtype: The torch dtype for the tensors.
Returns:
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
positional embeddings.
"""
# Create normalized 1D coordinate grids from 0 to 1.
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
# Create a 2D meshgrid of coordinates.
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
# Reshape positions to be broadcastable with frequencies.
# Shape becomes (patch_size^2, 1, 1).
pos_x = pos_x.reshape(-1, 1, 1)
pos_y = pos_y.reshape(-1, 1, 1)
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
# Reshape frequencies to be broadcastable for creating 2D basis functions.
# freqs_x shape: (1, max_freqs, 1)
# freqs_y shape: (1, 1, max_freqs)
freqs_x = freqs[None, :, None]
freqs_y = freqs[None, None, :]
# A custom weighting coefficient, not part of standard DCT.
# This seems to down-weight the contribution of higher-frequency interactions.
coeffs = (1 + freqs_x * freqs_y) ** -1
# Calculate the 1D cosine basis functions for x and y coordinates.
# This is the core of the DCT formulation.
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
# Combine the 1D basis functions to create 2D basis functions by element-wise
# multiplication, and apply the custom coefficients. Broadcasting handles the
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
# The result is flattened into a feature vector for each position.
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
return dct
def forward(self, inputs):
"""
Forward pass for the embedder.
Args:
inputs (Tensor): The input tensor of shape (B, P^2, C).
Returns:
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
"""
# Get the batch size, number of pixels, and number of channels.
B, P2, C = inputs.shape
# Store the original dtype to cast back to at the end.
original_dtype = inputs.dtype
# Force all operations within this module to run in fp32.
with torch.autocast("cuda", enabled=False):
# Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5)
inputs = inputs.float()
# Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
# Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1)
# Concatenate the original input features with the positional embeddings
# along the feature dimension.
inputs = torch.cat([inputs, dct], dim=-1)
# Project the combined tensor to the target hidden size.
inputs = self.embedder.float()(inputs)
return inputs.to(original_dtype)
class NerfGLUBlock(nn.Module):
"""
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
"""
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, use_compiled):
super().__init__()
# The total number of parameters for the MLP is increased to accommodate
# the gate, value, and output projection matrices.
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = nn.Linear(hidden_size_s, total_params)
self.norm = RMSNorm(hidden_size_x, use_compiled)
self.mlp_ratio = mlp_ratio
# nn.init.zeros_(self.param_generator.weight)
# nn.init.zeros_(self.param_generator.bias)
def forward(self, x, s):
batch_size, num_x, hidden_size_x = x.shape
mlp_params = self.param_generator(s)
# Split the generated parameters into three parts for the gate, value, and output projection.
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
# Reshape the parameters into matrices for batch matrix multiplication.
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
# Normalize the generated weight matrices as in the original implementation.
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
res_x = x
x = self.norm(x)
# Apply the final output projection.
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
x = x + res_x
return x
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, use_compiled):
super().__init__()
self.norm = RMSNorm(hidden_size, use_compiled=use_compiled)
self.linear = nn.Linear(hidden_size, out_channels)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x):
x = self.norm(x)
x = self.linear(x)
return x
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size, out_channels, use_compiled):
super().__init__()
self.norm = RMSNorm(hidden_size, use_compiled=use_compiled)
# replace nn.Linear with nn.Conv2d since linear is just pointwise conv
self.conv = nn.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
kernel_size=3,
padding=1
)
nn.init.zeros_(self.conv.weight)
nn.init.zeros_(self.conv.bias)
def forward(self, x):
# shape: [N, C, H, W] !
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So, we permute the dimensions to make the channel dimension the last one.
x_permuted = x.permute(0, 2, 3, 1) # Shape becomes [N, H, W, C]
# Apply normalization on the feature/channel dimension
x_norm = self.norm(x_permuted)
# Permute back to the original dimension order for the convolution
x_norm_permuted = x_norm.permute(0, 3, 1, 2) # Shape becomes [N, C, H, W]
# Apply the 3x3 convolution
x = self.conv(x_norm_permuted)
return x
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
super().__init__()

View File

@@ -156,13 +156,19 @@ class Chroma(nn.Module):
)
# TODO: move this hardcoded value to config
self.mod_index_length = 344
# single layer has 3 modulation vectors
# double layer has 6 modulation vectors for each expert
# final layer has 2 modulation vectors
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
self.depth_single_blocks = params.depth_single_blocks
self.depth_double_blocks = params.depth
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
self.register_buffer(
"mod_index",
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
persistent=False,
)
self.approximator_in_dim = params.approximator_in_dim
@property
def device(self):
@@ -213,7 +219,7 @@ class Chroma(nn.Module):
# then and only then we could concatenate it together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
mod_vectors_dict = distribute_modulations(mod_vectors)
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

View File

@@ -0,0 +1,380 @@
from dataclasses import dataclass
import torch
from torch import Tensor, nn
import torch.utils.checkpoint as ckpt
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
SingleStreamBlock,
timestep_embedding,
Approximator,
distribute_modulations,
NerfEmbedder,
NerfFinalLayer,
NerfFinalLayerConv,
NerfGLUBlock
)
@dataclass
class ChromaParams:
in_channels: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
approximator_in_dim: int
approximator_depth: int
approximator_hidden_size: int
patch_size: int
nerf_hidden_size: int
nerf_mlp_ratio: int
nerf_depth: int
nerf_max_freqs: int
_use_compiled: bool
chroma_params = ChromaParams(
in_channels=3,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
approximator_in_dim=64,
approximator_depth=5,
approximator_hidden_size=5120,
patch_size=16,
nerf_hidden_size=64,
nerf_mlp_ratio=4,
nerf_depth=4,
nerf_max_freqs=8,
_use_compiled=False,
)
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
"""
Modifies attention mask to allow attention to a few extra padding tokens.
Args:
mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
max_seq_length: Maximum sequence length of the model
num_extra_padding: Number of padding tokens to unmask
Returns:
Modified mask
"""
# Get the actual sequence length from the mask
seq_length = mask.sum(dim=-1)
batch_size = mask.shape[0]
modified_mask = mask.clone()
for i in range(batch_size):
current_seq_len = int(seq_length[i].item())
# Only add extra padding tokens if there's room
if current_seq_len < max_seq_length:
# Calculate how many padding tokens we can unmask
available_padding = max_seq_length - current_seq_len
tokens_to_unmask = min(num_extra_padding, available_padding)
# Unmask the specified number of padding tokens right after the sequence
modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
return modified_mask
class Chroma(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: ChromaParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
self.gradient_checkpointing = False
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
# self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
# patchify ops
self.img_in_patch = nn.Conv2d(
params.in_channels,
params.hidden_size,
kernel_size=params.patch_size,
stride=params.patch_size,
bias=True
)
nn.init.zeros_(self.img_in_patch.weight)
nn.init.zeros_(self.img_in_patch.bias)
# TODO: need proper mapping for this approximator output!
# currently the mapping is hardcoded in distribute_modulations function
self.distilled_guidance_layer = Approximator(
params.approximator_in_dim,
self.hidden_size,
params.approximator_hidden_size,
params.approximator_depth,
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
use_compiled=params._use_compiled,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
use_compiled=params._use_compiled,
)
for _ in range(params.depth_single_blocks)
]
)
# self.final_layer = LastLayer(
# self.hidden_size,
# 1,
# self.out_channels,
# use_compiled=params._use_compiled,
# )
# pixel channel concat with DCT
self.nerf_image_embedder = NerfEmbedder(
in_channels=params.in_channels,
hidden_size_input=params.nerf_hidden_size,
max_freqs=params.nerf_max_freqs
)
self.nerf_blocks = nn.ModuleList([
NerfGLUBlock(
hidden_size_s=params.hidden_size,
hidden_size_x=params.nerf_hidden_size,
mlp_ratio=params.nerf_mlp_ratio,
use_compiled=params._use_compiled
) for _ in range(params.nerf_depth)
])
# self.nerf_final_layer = NerfFinalLayer(
# params.nerf_hidden_size,
# out_channels=params.in_channels,
# use_compiled=params._use_compiled
# )
self.nerf_final_layer_conv = NerfFinalLayerConv(
params.nerf_hidden_size,
out_channels=params.in_channels,
use_compiled=params._use_compiled
)
# TODO: move this hardcoded value to config
# single layer has 3 modulation vectors
# double layer has 6 modulation vectors for each expert
# final layer has 2 modulation vectors
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
self.depth_single_blocks = params.depth_single_blocks
self.depth_double_blocks = params.depth
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
self.register_buffer(
"mod_index",
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
persistent=False,
)
self.approximator_in_dim = params.approximator_in_dim
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def enable_gradient_checkpointing(self, enable: bool = True):
self.gradient_checkpointing = enable
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
guidance: Tensor,
attn_padding: int = 1,
) -> Tensor:
if img.ndim != 4:
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
if txt.ndim != 3:
raise ValueError("Input txt tensors must have 3 dimensions.")
B, C, H, W = img.shape
# gemini gogogo idk how to unfold and pack the patch properly :P
# Store the raw pixel values of each patch for the NeRF head later.
# unfold creates patches: [B, C * P * P, NumPatches]
nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
# partchify ops
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
num_patches = img.shape[2] * img.shape[3]
# flatten into a sequence for the transformer.
img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
txt = self.txt_in(txt)
# TODO:
# need to fix grad accumulation issue here for now it's in no grad mode
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
# the fan out operation here is deleting the backward graph
# alternatively doing forward pass for every block manually is doable but slow
# custom backward probably be better
with torch.no_grad():
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim//4)
# TODO: need to add toggle to omit this from schnell but that's not a priority
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim//4)
# get all modulation index
modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim//2)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
# and we need to broadcast timestep and guidance along too
timestep_guidance = (
torch.cat([distill_timestep, distil_guidance], dim=1)
.unsqueeze(1)
.repeat(1, self.mod_index_length, 1)
)
# then and only then we could concatenate it together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
# compute mask
# assume max seq length from the batched input
max_len = txt.shape[1]
# mask
with torch.no_grad():
txt_mask_w_padding = modify_mask_to_attend_padding(
txt_mask, max_len, attn_padding
)
txt_img_mask = torch.cat(
[
txt_mask_w_padding,
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
],
dim=1,
)
txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
txt_img_mask = (
txt_img_mask[None, None, ...]
.repeat(txt.shape[0], self.num_heads, 1, 1)
.int()
.bool()
)
# txt_mask_w_padding[txt_mask_w_padding==False] = True
for i, block in enumerate(self.double_blocks):
# the guidance replaced by FFN output
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
double_mod = [img_mod, txt_mod]
# just in case in different GPU for simple pipeline parallel
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
img, txt = ckpt.checkpoint(
block, img, txt, pe, double_mod, txt_img_mask
)
else:
img, txt = block(
img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask
)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
else:
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
img = img[:, txt.shape[1] :, ...]
# final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
# img = self.final_layer(
# img, distill_vec=final_mod
# ) # (N, T, patch_size ** 2 * out_channels)
# aliasing
nerf_hidden = img
# reshape for per-patch processing
nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
# pass through the dynamic MLP blocks (the NeRF)
for i, block in enumerate(self.nerf_blocks):
if self.training:
img_dct = ckpt.checkpoint(block, img_dct, nerf_hidden)
else:
img_dct = block(img_dct, nerf_hidden)
# final projection to get the output pixel values
# img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
img_dct = self.nerf_final_layer_conv.norm(img_dct)
# gemini gogogo idk how to fold this properly :P
# Reassemble the patches into the final image.
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
# Reshape to combine with batch dimension for fold
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
img_dct = nn.functional.fold(
img_dct,
output_size=(H, W),
kernel_size=self.params.patch_size,
stride=self.params.patch_size
) # [B, Hidden, H, W]
img_dct = self.nerf_final_layer_conv.conv(img_dct)
return img_dct

134
toolkit/models/FakeVAE.py Normal file
View File

@@ -0,0 +1,134 @@
from diffusers import AutoencoderKL
from typing import Optional, Union
import torch
import torch.nn as nn
import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput
class Config:
in_channels = 3
out_channels = 3
down_block_types = ("1",)
up_block_types = ("1",)
block_out_channels = (1,)
latent_channels = 3 # usually 4
norm_num_groups = 1
sample_size = 512
scaling_factor = 1.0
# scaling_factor = 1.8
shift_factor = 0
def __getitem__(cls, x):
return getattr(cls, x)
class FakeVAE(nn.Module):
def __init__(self):
super().__init__()
self._dtype = torch.float32
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.config = Config()
@property
def dtype(self):
return self._dtype
@dtype.setter
def dtype(self, value):
self._dtype = value
@property
def device(self):
return self._device
@device.setter
def device(self, value):
self._device = value
# mimic to from torch
def to(self, *args, **kwargs):
# pull out dtype and device if they exist
if "dtype" in kwargs:
self._dtype = kwargs["dtype"]
if "device" in kwargs:
self._device = kwargs["device"]
return super().to(*args, **kwargs)
def enable_xformers_memory_efficient_attention(self):
pass
# @apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> AutoencoderKLOutput:
h = x
# moments = self.quant_conv(h)
# posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (h,)
class FakeDist:
def __init__(self, x):
self._sample = x
def sample(self):
return self._sample
return AutoencoderKLOutput(latent_dist=FakeDist(h))
def _decode(
self, z: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
dec = z
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# @apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def _set_gradient_checkpointing(self, module, value=False):
pass
def enable_tiling(self, use_tiling: bool = True):
pass
def disable_tiling(self):
pass
def enable_slicing(self):
pass
def disable_slicing(self):
pass
def set_use_memory_efficient_attention_xformers(self, value: bool = True):
pass
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
dec = sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)