Added support for Lodestone Rock's Chroma model

This commit is contained in:
Jaret Burkett
2025-04-05 13:21:36 -06:00
parent 2b901cca39
commit 7c21eac1b3
9 changed files with 1499 additions and 0 deletions

View File

@@ -0,0 +1,97 @@
---
job: extension
config:
# this name will be used for the output folder and filenames
name: "my_first_chroma_lora_v1"
process:
- type: 'sd_trainer'
# root folder to save training sessions/samples/weights
training_folder: "output"
# uncomment to see performance stats in the terminal every N steps
# performance_log_every: 1000
device: cuda:0
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 16
linear_alpha: 16
save:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false # change this to true to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
# on windows, escape back slashes with another backslash so
# "C:\\path\\to\\images\\folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
train:
batch_size: 1
steps: 2000 # total number of steps to train; 500 - 4000 is a good range
gradient_accumulation: 1
train_unet: true
train_text_encoder: false # probably won't work with chroma
gradient_checkpointing: true # leave this on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw8bit"
lr: 1e-4
# uncomment this to skip the pre-training sample
# skip_first_sample: true
# uncomment to completely disable sampling
# disable_sampling: true
# uncomment to use new bell curved weighting. Experimental but may produce better results
# linear_timesteps: true
# ema will smooth out learning, but could slow it down. Recommended to leave on.
ema_config:
use_ema: true
ema_decay: 0.99
# chroma will probably need bf16 if your gpu supports it; other dtypes may not work correctly
dtype: bf16
model:
# Download whichever model you prefer from the Chroma repo
# https://huggingface.co/lodestones/Chroma/tree/main
# point to it here.
name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
arch: "chroma"
quantize: true # run 8bit mixed precision
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
width: 1024
height: 1024
prompts:
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
- "hipster man with a beard, building a chair, in a wood shop"
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
- "a man holding a sign that says, 'this is a sign'"
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
neg: "" # negative prompt, optional
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
name: "[name]"
version: '1.0'
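
For reference, a minimal sanity check of the config above; a sketch that assumes the file is saved as config/my_first_chroma_lora_v1.yaml (illustrative path) and that PyYAML is installed. In ai-toolkit the file is normally just passed to the training entry point, so this only confirms the structure parses:

import yaml

with open("config/my_first_chroma_lora_v1.yaml") as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]
assert process["model"]["arch"] == "chroma"
# the sample sampler must match the training noise scheduler (both "flowmatch")
assert process["sample"]["sampler"] == process["train"]["noise_scheduler"]
print(process["model"]["name_or_path"], process["train"]["steps"])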

View File

@@ -0,0 +1,6 @@
from .chroma import ChromaModel
AI_TOOLKIT_MODELS = [
# put a list of models here
ChromaModel
]

View File

@@ -0,0 +1 @@
from .chroma_model import ChromaModel

View File

@@ -0,0 +1,388 @@
import os
from typing import TYPE_CHECKING
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from diffusers import AutoencoderKL
# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from .pipeline import ChromaPipeline
from einops import rearrange, repeat
import random
import torch.nn.functional as F
from .src.model import Chroma, chroma_params
from safetensors.torch import load_file, save_file
from toolkit.metadata import get_meta_for_safetensors
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True
}
class FakeConfig:
# for diffusers compatibility
def __init__(self):
self.attention_head_dim = 128
self.guidance_embeds = True
self.in_channels = 64
self.joint_attention_dim = 4096
self.num_attention_heads = 24
self.num_layers = 19
self.num_single_layers = 38
self.patch_size = 1
class FakeCLIP(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.bfloat16
self.device = 'cuda'
self.text_model = None
self.tokenizer = None
self.model_max_length = 77
def forward(self, *args, **kwargs):
return torch.zeros(1, 1, 1).to(self.device)
class ChromaModel(BaseModel):
arch = "chroma"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['Chroma']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
# return the bucket divisibility for the model
return 32
def load_model(self):
dtype = self.torch_dtype
# will be updated if we detect an existing checkpoint in the training folder
model_path = self.model_config.name_or_path
extras_path = 'black-forest-labs/FLUX.1-schnell'
self.print_and_status_update("Loading transformer")
transformer = Chroma(chroma_params)
# add dtype, not sure why it doesn't have it
transformer.dtype = dtype
chroma_state_dict = load_file(model_path, 'cpu')
# load the state dict into the model
transformer.load_state_dict(chroma_state_dict)
transformer.to(self.quantize_device, dtype=dtype)
transformer.config = FakeConfig()
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading T5")
tokenizer_2 = T5TokenizerFast.from_pretrained(
extras_path, subfolder="tokenizer_2", torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
extras_path, subfolder="text_encoder_2", torch_dtype=dtype
)
text_encoder_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing T5")
quantize(text_encoder_2, weights=get_qtype(
self.model_config.qtype))
freeze(text_encoder_2)
flush()
# self.print_and_status_update("Loading CLIP")
text_encoder = FakeCLIP()
tokenizer = FakeCLIP()
text_encoder.to(self.device_torch, dtype=dtype)
self.noise_scheduler = ChromaModel.get_train_scheduler()
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
extras_path,
subfolder="vae",
torch_dtype=dtype
)
vae = vae.to(self.device_torch, dtype=dtype)
self.print_and_status_update("Making pipe")
pipe: ChromaPipeline = ChromaPipeline(
scheduler=self.noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=None,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=None,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
text_encoder[1].to(self.device_torch)
text_encoder[1].requires_grad_(False)
text_encoder[1].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = ChromaModel.get_train_scheduler()
pipeline = ChromaPipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
text_encoder_2=unwrap_model(self.text_encoder[1]),
tokenizer_2=self.tokenizer[1],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer)
)
# pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: ChromaPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
prompt_attn_mask=conditional_embeds.attention_mask,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
with torch.no_grad():
bs, c, h, w = latent_model_input.shape
latent_model_input_packed = rearrange(
latent_model_input,
"b c (h ph) (w pw) -> b (h w) (c ph pw)",
ph=2,
pw=2
)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c",
b=bs).to(self.device_torch)
txt_ids = torch.zeros(
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32)
guidance = guidance.expand(latent_model_input_packed.shape[0])
cast_dtype = self.unet.dtype
noise_pred = self.unet(
img=latent_model_input_packed.to(
self.device_torch, cast_dtype
),
img_ids=img_ids,
txt=text_embeddings.text_embeds.to(
self.device_torch, cast_dtype
),
txt_ids=txt_ids,
txt_mask=text_embeddings.attention_mask.to(
self.device_torch, cast_dtype
),
timesteps=timestep / 1000,
guidance=guidance
)
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
noise_pred = rearrange(
noise_pred,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=latent_model_input.shape[2] // 2,
w=latent_model_input.shape[3] // 2,
ph=2,
pw=2,
c=self.vae.config.latent_channels
)
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if isinstance(prompt, str):
prompts = [prompt]
else:
prompts = prompt
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
max_length = 512
device = self.text_encoder[1].device
dtype = self.text_encoder[1].dtype
# T5
text_inputs = self.tokenizer[1](
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder[1].dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_attention_mask = text_inputs["attention_mask"]
pe = PromptEmbeds(
prompt_embeds
)
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return self.model.final_layer.linear.weight.requires_grad
def get_te_has_grad(self):
# return from a weight if it has grad
return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: Chroma = unwrap_model(self.model)
state_dict = transformer.state_dict()
save_dict = {}
for k, v in state_dict.items():
if isinstance(v, QTensor):
v = v.dequantize()
save_dict[k] = v.clone().to('cpu', dtype=save_dtype)
meta = get_meta_for_safetensors(meta, name='chroma')
save_file(save_dict, output_path, metadata=meta)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
return (noise - batch.latents).detach()
def convert_lora_weights_before_save(self, state_dict):
# keys currently start with "transformer." but need to start with "diffusion_model." for comfyui
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
# saved with the "diffusion_model." prefix, but ai-toolkit expects "transformer."
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
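
For reference, a standalone sketch of the 2x2 latent packing that get_noise_prediction above performs with einops before calling the transformer, plus the inverse unpacking applied to the model output (the 16 latent channels and 64x64 latent grid are illustrative assumptions):

import torch
from einops import rearrange

b, c, h, w = 1, 16, 64, 64  # 16 latent channels, 64x64 latent grid
latents = torch.randn(b, c, h, w)

# pack: every 2x2 latent patch becomes one token of dim c * 2 * 2 = 64
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
assert packed.shape == (b, (h // 2) * (w // 2), c * 4)  # (1, 1024, 64)

# unpack: the exact inverse, as applied to the transformer output
unpacked = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=h // 2, w=w // 2, ph=2, pw=2, c=c)
assert torch.equal(unpacked, latents)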

View File

@@ -0,0 +1,195 @@
from typing import Union, List, Optional, Dict, Any, Callable
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import is_torch_xla_available
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
class ChromaPipeline(FluxPipeline):
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attn_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attn_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[
int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16)
if guidance_scale > 1.00001:
negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16)
# 4. Prepare latent variables
num_channels_latents = 64 // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# extend img ids to match batch size
latent_image_ids = latent_image_ids.unsqueeze(0)
latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
guidance = torch.full([1], 0, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
noise_pred_text = self.transformer(
img=latents,
img_ids=latent_image_ids,
txt=prompt_embeds,
txt_ids=text_ids,
txt_mask=prompt_attn_mask, # todo add this
timesteps=timestep / 1000,
guidance=guidance
)
if guidance_scale > 1.00001:
noise_pred_uncond = self.transformer(
img=latents,
img_ids=latent_image_ids,
txt=negative_prompt_embeds,
txt_ids=negative_text_ids,
txt_mask=negative_prompt_attn_mask, # todo add this
timesteps=timestep / 1000,
guidance=guidance
)
noise_pred = noise_pred_uncond + self.guidance_scale * \
(noise_pred_text - noise_pred_uncond)
else:
noise_pred = noise_pred_text
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop(
"prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + \
self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(
image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
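
For reference, a standalone sketch of the dynamic timestep shift computed above. The constants come from scheduler_config in the Chroma model file; the linear interpolation is my reading of what the diffusers calculate_shift helper does, so treat it as an approximation rather than a drop-in replacement:

import numpy as np

num_inference_steps = 25
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

image_seq_len = (1024 // 8 // 2) ** 2  # 1024px image -> 128x128 latent -> 64x64 packed tokens = 4096
base_seq, max_seq = 256, 4096
base_shift, max_shift = 0.5, 1.15

m = (max_shift - base_shift) / (max_seq - base_seq)
mu = image_seq_len * m + (base_shift - m * base_seq)
print(round(mu, 3))  # 1.15 at the maximum sequence length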

View File

@@ -0,0 +1 @@
# This is taken and slightly modified from https://github.com/lodestone-rock/flow/tree/master/src/models/chroma

View File

@@ -0,0 +1,505 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
import torch.nn.functional as F
from .math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, use_compiled: bool = False):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
self.use_compiled = use_compiled
def _forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
def forward(self, x: Tensor):
return F.rms_norm(x, self.scale.shape, weight=self.scale, eps=1e-6)
# if self.use_compiled:
# return torch.compile(self._forward)(x)
# else:
# return self._forward(x)
def distribute_modulations(tensor: torch.Tensor):
"""
Distributes slices of the tensor into the block_dict as ModulationOut objects.
Args:
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
"""
batch_size, vectors, dim = tensor.shape
block_dict = {}
# HARD CODED VALUES! lookup table for the generated vectors
# TODO: move this into chroma config!
# Add 38 single mod blocks
for i in range(38):
key = f"single_blocks.{i}.modulation.lin"
block_dict[key] = None
# Add 19 image double blocks
for i in range(19):
key = f"double_blocks.{i}.img_mod.lin"
block_dict[key] = None
# Add 19 text double blocks
for i in range(19):
key = f"double_blocks.{i}.txt_mod.lin"
block_dict[key] = None
# Add the final layer
block_dict["final_layer.adaLN_modulation.1"] = None
# 6.2b version
block_dict["lite_double_blocks.4.img_mod.lin"] = None
block_dict["lite_double_blocks.4.txt_mod.lin"] = None
idx = 0 # Index to keep track of the vector slices
for key in block_dict.keys():
if "single_blocks" in key:
# Single block: 1 ModulationOut
block_dict[key] = ModulationOut(
shift=tensor[:, idx : idx + 1, :],
scale=tensor[:, idx + 1 : idx + 2, :],
gate=tensor[:, idx + 2 : idx + 3, :],
)
idx += 3 # Advance by 3 vectors
elif "img_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx : idx + 1, :],
scale=tensor[:, idx + 1 : idx + 2, :],
gate=tensor[:, idx + 2 : idx + 3, :],
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "txt_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx : idx + 1, :],
scale=tensor[:, idx + 1 : idx + 2, :],
gate=tensor[:, idx + 2 : idx + 3, :],
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "final_layer" in key:
# Final layer: 1 ModulationOut
block_dict[key] = [
tensor[:, idx : idx + 1, :],
tensor[:, idx + 1 : idx + 2, :],
]
idx += 2  # Advance by 2 vectors
return block_dict
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)]
)
self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def forward(self, x: Tensor) -> Tensor:
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
x = self.out_proj(x)
return x
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, use_compiled: bool = False):
super().__init__()
self.query_norm = RMSNorm(dim, use_compiled=use_compiled)
self.key_norm = RMSNorm(dim, use_compiled=use_compiled)
self.use_compiled = use_compiled
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
use_compiled: bool = False,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim, use_compiled=use_compiled)
self.proj = nn.Linear(dim, dim)
self.use_compiled = use_compiled
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
def _modulation_shift_scale_fn(x, scale, shift):
return (1 + scale) * x + shift
def _modulation_gate_fn(x, gate, gate_params):
return x + gate * gate_params
class DoubleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
qkv_bias: bool = False,
use_compiled: bool = False,
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_compiled=use_compiled,
)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_compiled=use_compiled,
)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.use_compiled = use_compiled
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def modulation_shift_scale_fn(self, x, scale, shift):
if self.use_compiled:
return torch.compile(_modulation_shift_scale_fn)(x, scale, shift)
else:
return _modulation_shift_scale_fn(x, scale, shift)
def modulation_gate_fn(self, x, gate, gate_params):
if self.use_compiled:
return torch.compile(_modulation_gate_fn)(x, gate, gate_params)
else:
return _modulation_gate_fn(x, gate, gate_params)
def forward(
self,
img: Tensor,
txt: Tensor,
pe: Tensor,
distill_vec: list[ModulationOut],
mask: Tensor,
) -> tuple[Tensor, Tensor]:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec
# prepare image for attention
img_modulated = self.img_norm1(img)
# replaced with compiled fn
# img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = self.modulation_shift_scale_fn(
img_modulated, img_mod1.scale, img_mod1.shift
)
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
# replaced with compiled fn
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = self.modulation_shift_scale_fn(
txt_modulated, txt_mod1.scale, txt_mod1.shift
)
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe, mask=mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
# replaced with compiled fn
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
img = self.modulation_gate_fn(
img,
img_mod2.gate,
self.img_mlp(
self.modulation_shift_scale_fn(
self.img_norm2(img), img_mod2.scale, img_mod2.shift
)
),
)
# calculate the txt blocks
# replaced with compiled fn
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt = self.modulation_gate_fn(
txt,
txt_mod2.gate,
self.txt_mlp(
self.modulation_shift_scale_fn(
self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift
)
),
)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
use_compiled: bool = False,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim, use_compiled=use_compiled)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.use_compiled = use_compiled
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def modulation_shift_scale_fn(self, x, scale, shift):
if self.use_compiled:
return torch.compile(_modulation_shift_scale_fn)(x, scale, shift)
else:
return _modulation_shift_scale_fn(x, scale, shift)
def modulation_gate_fn(self, x, gate, gate_params):
if self.use_compiled:
return torch.compile(_modulation_gate_fn)(x, gate, gate_params)
else:
return _modulation_gate_fn(x, gate, gate_params)
def forward(
self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor
) -> Tensor:
mod = distill_vec
# replaced with compiled fn
# x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
# replaced with compiled fn
# return x + mod.gate * output
return self.modulation_gate_fn(x, mod.gate, output)
class LastLayer(nn.Module):
def __init__(
self,
hidden_size: int,
patch_size: int,
out_channels: int,
use_compiled: bool = False,
):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.use_compiled = use_compiled
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def modulation_shift_scale_fn(self, x, scale, shift):
if self.use_compiled:
return torch.compile(_modulation_shift_scale_fn)(x, scale, shift)
else:
return _modulation_shift_scale_fn(x, scale, shift)
def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor:
shift, scale = distill_vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
# replaced with compiled fn
# x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.modulation_shift_scale_fn(
self.norm_final(x), scale[:, None, :], shift[:, None, :]
)
x = self.linear(x)
return x
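
For reference, a standalone check of the shapes and counts defined above, assuming the file is importable as layers (hypothetical module path; adjust to wherever it lands in the repo):

import torch
from layers import timestep_embedding  # hypothetical import path

t = torch.tensor([0.25, 0.5])  # two timesteps in [0, 1]
emb = timestep_embedding(t, 16)  # scaled internally by time_factor=1000
assert emb.shape == (2, 16)

# modulation vector budget consumed by distribute_modulations:
# 38 single blocks * 3 + 19 img double blocks * 6 + 19 txt double blocks * 6 + 2 final-layer vectors
assert 38 * 3 + 19 * 6 + 19 * 6 + 2 == 344  # matches Chroma.mod_index_length in model.py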

View File

@@ -0,0 +1,33 @@
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
q, k = apply_rope(q, k, pe)
# mask should be broadcastable to [B, H, L, L]
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack(
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
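
For reference, a standalone sketch of the shapes these helpers expect, with rope and apply_rope from this module in scope. The sizes are illustrative, and a single positional axis of the full head dim is used instead of the [16, 56, 56] split that EmbedND applies:

import torch

B, H, L, D = 1, 24, 8, 128  # D = head dim, must be even
pos = torch.arange(L, dtype=torch.float64)[None, :]  # (B, L) positional ids for one axis
pe = rope(pos, D, theta=10_000)  # (B, L, D/2, 2, 2)
pe = pe.unsqueeze(1)  # EmbedND adds a head axis: (B, 1, L, D/2, 2, 2)

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
q_rot, k_rot = apply_rope(q, k, pe)  # shapes are preserved
assert q_rot.shape == q.shape and k_rot.shape == k.shape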

View File

@@ -0,0 +1,273 @@
from dataclasses import dataclass
import torch
from torch import Tensor, nn
import torch.utils.checkpoint as ckpt
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
SingleStreamBlock,
timestep_embedding,
Approximator,
distribute_modulations,
)
@dataclass
class ChromaParams:
in_channels: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
approximator_in_dim: int
approximator_depth: int
approximator_hidden_size: int
_use_compiled: bool
chroma_params = ChromaParams(
in_channels=64,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
approximator_in_dim=64,
approximator_depth=5,
approximator_hidden_size=5120,
_use_compiled=False,
)
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
"""
Modifies attention mask to allow attention to a few extra padding tokens.
Args:
mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
max_seq_length: Maximum sequence length of the model
num_extra_padding: Number of padding tokens to unmask
Returns:
Modified mask
"""
# Get the actual sequence length from the mask
seq_length = mask.sum(dim=-1)
batch_size = mask.shape[0]
modified_mask = mask.clone()
for i in range(batch_size):
current_seq_len = int(seq_length[i].item())
# Only add extra padding tokens if there's room
if current_seq_len < max_seq_length:
# Calculate how many padding tokens we can unmask
available_padding = max_seq_length - current_seq_len
tokens_to_unmask = min(num_extra_padding, available_padding)
# Unmask the specified number of padding tokens right after the sequence
modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
return modified_mask
class Chroma(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: ChromaParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
# TODO: need proper mapping for this approximator output!
# currently the mapping is hardcoded in distribute_modulations function
self.distilled_guidance_layer = Approximator(
params.approximator_in_dim,
self.hidden_size,
params.approximator_hidden_size,
params.approximator_depth,
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
use_compiled=params._use_compiled,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
use_compiled=params._use_compiled,
)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(
self.hidden_size,
1,
self.out_channels,
use_compiled=params._use_compiled,
)
# TODO: move this hardcoded value to config
self.mod_index_length = 344
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
self.register_buffer(
"mod_index",
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
persistent=False,
)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
guidance: Tensor,
attn_padding: int = 1,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
txt = self.txt_in(txt)
# TODO:
# need to fix the grad accumulation issue here; for now it's in no_grad mode
# besides, i don't want to wash out the PFP that's trained into these model weights anyway
# the fan-out operation here is deleting the backward graph
# alternatively, doing the forward pass for every block manually is doable but slow
# a custom backward would probably be better
with torch.no_grad():
distill_timestep = timestep_embedding(timesteps, 16)
# TODO: need to add toggle to omit this from schnell but that's not a priority
distil_guidance = timestep_embedding(guidance, 16)
# get all modulation index
modulation_index = timestep_embedding(self.mod_index, 32)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
# and we need to broadcast timestep and guidance along too
timestep_guidance = (
torch.cat([distill_timestep, distil_guidance], dim=1)
.unsqueeze(1)
.repeat(1, self.mod_index_length, 1)
)
# then and only then we could concatenate it together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
mod_vectors_dict = distribute_modulations(mod_vectors)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
# compute mask
# assume max seq length from the batched input
max_len = txt.shape[1]
# mask
with torch.no_grad():
txt_mask_w_padding = modify_mask_to_attend_padding(
txt_mask, max_len, attn_padding
)
txt_img_mask = torch.cat(
[
txt_mask_w_padding,
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
],
dim=1,
)
txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
txt_img_mask = (
txt_img_mask[None, None, ...]
.repeat(txt.shape[0], self.num_heads, 1, 1)
.int()
.bool()
)
# txt_mask_w_padding[txt_mask_w_padding==False] = True
for i, block in enumerate(self.double_blocks):
# the guidance replaced by FFN output
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
double_mod = [img_mod, txt_mod]
# just in case it's on a different GPU, for simple pipeline parallelism
if self.training:
img.requires_grad_(True)
img, txt = ckpt.checkpoint(
block, img, txt, pe, double_mod, txt_img_mask
)
else:
img, txt = block(
img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask
)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
if self.training:
img.requires_grad_(True)
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
else:
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
img = img[:, txt.shape[1] :, ...]
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
img = self.final_layer(
img, distill_vec=final_mod
) # (N, T, patch_size ** 2 * out_channels)
return img
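
Finally, for reference, a toy illustration of modify_mask_to_attend_padding from the top of this file, with the function in scope and made-up tensor values:

import torch

mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]])  # 3 real tokens, 5 padding tokens
out = modify_mask_to_attend_padding(mask, max_seq_length=8, num_extra_padding=2)
print(out)  # tensor([[1, 1, 1, 1, 1, 0, 0, 0]]) -- two padding tokens right after the sequence are unmasked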