Add support for training lodestones/Zeta-Chroma

This commit is contained in:
Jaret Burkett
2026-03-01 12:52:29 -07:00
parent 40f995f616
commit 57d407cfd4
7 changed files with 1320 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusMod
from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
from .z_image import ZImageModel
from .ltx2 import LTX2Model
from .zeta_chroma import ZetaChromaModel
AI_TOOLKIT_MODELS = [
# put a list of models here
@@ -29,4 +30,5 @@ AI_TOOLKIT_MODELS = [
LTX2Model,
Flux2Klein4BModel,
Flux2Klein9BModel,
ZetaChromaModel,
]

View File

@@ -0,0 +1 @@
from .zeta_chroma_model import ZetaChromaModel

View File

@@ -0,0 +1,380 @@
import os
from typing import List, Optional
import huggingface_hub
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file
from optimum.quanto import QTensor
from toolkit.metadata import get_meta_for_safetensors
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer, Qwen3ForCausalLM
from diffusers import AutoencoderKL
from toolkit.models.FakeVAE import FakeVAE
from .zeta_chroma_transformer import ZImageDCT, ZImageDCTParams, vae_flatten, vae_unflatten, prepare_latent_image_ids, make_text_position_ids
from .zeta_chroma_pipeline import ZetaChromaPipeline
# Flow-matching scheduler settings for training. Dynamic shifting is off and a
# fixed shift of 3.0 is used — presumably matching the checkpoint's training
# config; confirm against the upstream repo.
scheduler_config = {
    "num_train_timesteps": 1000,
    "use_dynamic_shifting": False,
    "shift": 3.0,
}

# Default weight file inside the lodestones/Zeta-Chroma HF repo; users can
# override by appending a different *.safetensors name to the repo id.
ZETA_CHROMA_TRANSFORMER_FILENAME = "zeta-chroma-base-x0-pixel-dino-distance.safetensors"
class ZetaChromaModel(BaseModel):
    """ai-toolkit wrapper for lodestones' Zeta-Chroma model.

    The transformer (:class:`ZImageDCT`) operates directly on flattened
    32x32 RGB pixel patches — there is no latent VAE, so a :class:`FakeVAE`
    with ``scaling_factor=1.0`` stands in. Prompts are encoded with the Qwen3
    text encoder taken from the repo in ``model_config.extras_name_or_path``
    (e.g. Tongyi-MAI/Z-Image-Turbo).
    """

    arch = "zeta_chroma"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["ZImageDCT"]
        # Pixel patch size of the transformer grid; also the bucket divisibility.
        self.patch_size = 32
        # Max tokenized prompt length fed to the text encoder.
        self.max_sequence_length = 512

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        """Return a fresh flow-match Euler scheduler with Zeta-Chroma settings."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        # Image H/W must be multiples of the 32px patch grid.
        return self.patch_size

    def load_model(self):
        """Load transformer weights, Qwen3 text encoder, and build the pipeline.

        Accepts a local safetensors path, an HF repo id, or a repo id with an
        explicit ``*.safetensors`` filename appended. Applies optional
        quantization and layer offloading per the model config.
        """
        dtype = self.torch_dtype
        # NOTE(review): status string says "ZImage"; presumably copied from the
        # z_image loader — kept as-is.
        self.print_and_status_update("Loading ZImage model")
        model_path = self.model_config.name_or_path
        # Repo that provides tokenizer/text_encoder subfolders.
        base_model_path = self.model_config.extras_name_or_path

        self.print_and_status_update("Loading transformer")
        transformer_path = model_path
        if not os.path.exists(transformer_path):
            transformer_name = ZETA_CHROMA_TRANSFORMER_FILENAME
            # if path ends with .safetensors, assume last part is filename
            # this allows users to target different file names in the repo like
            # lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors
            if transformer_path.endswith(".safetensors"):
                splits = transformer_path.split("/")
                transformer_name = splits[-1]
                transformer_path = "/".join(splits[:-1])
            # assume it is from the hub
            transformer_path = huggingface_hub.hf_hub_download(
                repo_id=transformer_path,
                filename=transformer_name,
            )
        transformer_state_dict = load_file(transformer_path, device="cpu")
        # cast to dtype
        for key in transformer_state_dict:
            transformer_state_dict[key] = transformer_state_dict[key].to(dtype)

        # Auto-detect use_x0 from checkpoint: x0-parameterized checkpoints
        # carry an (empty) "__x0__" marker buffer.
        use_x0 = "__x0__" in transformer_state_dict

        # Build model params; each token is a full flattened RGB pixel patch.
        in_channels = self.patch_size * self.patch_size * 3  # RGB patches
        model_params = ZImageDCTParams(
            patch_size=1,
            in_channels=in_channels,
            use_x0=use_x0,
        )
        # Instantiate on the meta device, then assign real tensors so weights
        # are never allocated twice.
        with torch.device("meta"):
            transformer = ZImageDCT(model_params)
        transformer.load_state_dict(transformer_state_dict, assign=True)
        del transformer_state_dict
        transformer.to(self.quantize_device, dtype=dtype)
        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            # Pad tokens are tiny parameters used on every step; keep resident.
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=[
                    transformer.x_pad_token,
                    transformer.cap_pad_token,
                ],
            )
        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")
        flush()

        self.print_and_status_update("Text Encoder")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype
        )
        text_encoder = Qwen3ForCausalLM.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype
        )
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )
        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAE")
        # Pixel-space model: identity VAE stand-in.
        vae = FakeVAE(scaling_factor=1.0)
        vae.to(self.device_torch, dtype=dtype)

        self.noise_scheduler = ZetaChromaModel.get_train_scheduler()
        self.print_and_status_update("Making pipe")
        kwargs = {}
        pipe: ZetaChromaPipeline = ZetaChromaPipeline(
            scheduler=self.noise_scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
            **kwargs,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")
        # BaseModel convention: text encoders / tokenizers are stored as lists.
        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Build a sampling pipeline around the (unwrapped) current modules."""
        scheduler = ZetaChromaModel.get_train_scheduler()
        pipeline: ZetaChromaPipeline = ZetaChromaPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )
        pipeline = pipeline.to(self.device_torch)
        return pipeline

    def generate_single_image(
        self,
        pipeline: ZetaChromaPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one image; requested W/H are snapped down to the patch grid."""
        self.model.to(self.device_torch, dtype=self.torch_dtype)
        self.model.to(self.device_torch)
        sc = self.get_bucket_divisibility()
        # Round dimensions down to a multiple of the 32px patch size.
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_embeds_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,  # pixel tensor, [B, 3, H, W]
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        """Patchify pixels, run the transformer, and unpatchify the prediction."""
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)
        # Build ids/masks without tracking grads; the transformer call itself
        # stays outside so gradients can flow during LoRA training.
        with torch.no_grad():
            pixel_shape = latent_model_input.shape
            # todo: do we invert like this?
            # t_vec = (1000 - timestep) / 1000
            t_vec = timestep / 1000
            height = latent_model_input.shape[2]
            h_patches = height // self.patch_size
            width = latent_model_input.shape[3]
            w_patches = width // self.patch_size
            batch_size = latent_model_input.shape[0]
            img, _ = vae_flatten(latent_model_input, patch_size=self.patch_size)
            num_patches = img.shape[1]

            # --- Build position IDs ---
            # Image token positions start right after the prompt's valid tokens.
            pos_lengths = text_embeddings.attention_mask.sum(1)
            offset = pos_lengths
            image_pos_ids = prepare_latent_image_ids(
                offset, h_patches, w_patches, patch_size=1
            ).to(self.device_torch)
            pos_text_ids = make_text_position_ids(pos_lengths, self.max_sequence_length).to(
                self.device_torch
            )
            img_mask = torch.ones(
                (batch_size, num_patches), device=self.device_torch, dtype=torch.bool
            )
        pred = self.transformer(
            img=img,  # (B, num_patches, 3072)
            img_ids=image_pos_ids,  # (B, num_patches, 3)
            img_mask=img_mask,  # (B, num_patches)
            txt=text_embeddings.text_embeds,  # (B, 512, 2560)
            txt_ids=pos_text_ids,  # (B, 512, 3)
            txt_mask=text_embeddings.attention_mask,  # (B, 512)
            timesteps=t_vec,  # (B,)
        )
        pred = vae_unflatten(pred.float(), pixel_shape, patch_size=self.patch_size)
        return pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode prompt(s) with the pipeline's Qwen3 encoder."""
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)
        prompt_embeds, mask = self.pipeline._encode_prompts(
            prompt,
        )
        # Second element is None: this arch has no pooled embedding.
        pe = PromptEmbeds([prompt_embeds, None], attention_mask=mask)
        return pe

    def get_model_has_grad(self):
        # Base weights stay frozen; training happens through the LoRA network.
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        """Save (dequantized) transformer weights to a single safetensors file."""
        if not output_path.endswith(".safetensors"):
            output_path = output_path + ".safetensors"
        # only save the unet
        transformer: ZImageDCT = unwrap_model(self.model)
        state_dict = transformer.state_dict()
        save_dict = {}
        for k, v in state_dict.items():
            # Quantized tensors must be dequantized before saving.
            if isinstance(v, QTensor):
                v = v.dequantize()
            save_dict[k] = v.clone().to("cpu", dtype=save_dtype)
        meta = get_meta_for_safetensors(meta, name="zeta_chroma")
        save_file(save_dict, output_path, metadata=meta)

    def get_loss_target(self, *args, **kwargs):
        """Flow-matching target: velocity (noise - clean pixels)."""
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()

    def get_base_model_version(self):
        return "zeta_chroma"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # Attribute name of the main transformer stack (see ZImageDCT.layers).
        return ["layers"]

    def convert_lora_weights_before_save(self, state_dict):
        """Rename LoRA keys from ``transformer.`` to ``diffusion_model.`` prefix."""
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        """Inverse of :meth:`convert_lora_weights_before_save`."""
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

View File

@@ -0,0 +1,180 @@
from diffusers.pipelines.z_image.pipeline_z_image import (
ZImagePipeline,
calculate_shift,
retrieve_timesteps,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.utils.torch_utils import randn_tensor
import torch
from diffusers.utils import logging, replace_example_docstring
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from extensions_built_in.diffusion_models.zeta_chroma.zeta_chroma_transformer import (
get_schedule,
prepare_latent_image_ids,
make_text_position_ids,
vae_unflatten,
)
class ZetaChromaPipeline(ZImagePipeline):
    """Sampling pipeline for Zeta-Chroma.

    Reuses the diffusers ``ZImagePipeline`` plumbing but runs the ZImageDCT
    transformer directly on flattened 32x32 pixel patches, with a
    sequence-length-shifted schedule and classic two-pass CFG.
    """

    # NOTE(review): placeholder attribute with no visible reader in this
    # file — confirm whether it can be removed.
    need_something_here = True
    # Pixel patch size of the transformer grid.
    patch_size = 32
    # Prompt token budget for the Qwen3 encoder.
    max_sequence_length = 512

    @torch.no_grad()
    def _encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a list of prompts with the Qwen3 chat template.

        Returns ``(embeddings, mask)``: the second-to-last hidden states,
        padded/truncated to ``max_sequence_length``, and the boolean
        attention mask.
        """
        formatted = []
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            formatted.append(text)
        inputs = self.tokenizer(
            formatted,
            padding="max_length",
            max_length=self.max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.text_encoder.device)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = self.text_encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=True,
            )
            # Second-to-last hidden state (same as training)
            embeddings = outputs.hidden_states[-2]
        mask = inputs.attention_mask.bool()
        return embeddings, mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        cfg_normalization: bool = False,
        cfg_truncation: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Generate images from precomputed prompt embeddings.

        NOTE(review): raw ``prompt``/``negative_prompt`` strings, ``sigmas``,
        ``latents``, CFG normalization/truncation and the callback arguments
        are accepted for signature compatibility but unused here — embeddings
        and masks must be supplied directly. ``height``/``width`` must be
        multiples of 32.
        """
        device = self._execution_device
        batch_size = len(prompt_embeds)
        device = self._execution_device

        patch_size = self.patch_size
        # One token = one flattened 32x32 RGB patch.
        in_channels = patch_size * patch_size * 3
        h_patches = height // patch_size
        w_patches = width // patch_size
        num_patches = h_patches * w_patches

        pos_embeds, pos_mask = prompt_embeds, prompt_embeds_mask
        neg_embeds, neg_mask = negative_prompt_embeds, negative_prompt_embeds_mask

        # --- Build position IDs ---
        # Image tokens sit after the longer of the two prompts so both CFG
        # passes can share one image-id tensor.
        pos_lengths = pos_mask.sum(1)
        neg_lengths = neg_mask.sum(1)
        offset = torch.maximum(pos_lengths, neg_lengths)
        image_pos_ids = prepare_latent_image_ids(
            offset, h_patches, w_patches, patch_size=1
        ).to(device)
        pos_text_ids = make_text_position_ids(pos_lengths, max_sequence_length).to(
            device
        )
        neg_text_ids = make_text_position_ids(neg_lengths, max_sequence_length).to(
            device
        )

        # --- Initial noise ---
        noise = randn_tensor(
            (batch_size, num_patches, in_channels),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        # --- Timestep schedule ---
        # Shifted by sequence length; runs 1 (noise) -> 0 (clean).
        timesteps = get_schedule(num_inference_steps, num_patches)

        # --- Denoising loop (CFG) ---
        img = noise
        img_mask = torch.ones(
            (batch_size, num_patches), device=device, dtype=torch.bool
        )

        self._num_timesteps = len(timesteps)
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                t_vec = torch.full(
                    (batch_size,), t_curr, dtype=self.dtype, device=device
                )
                pred = self.transformer(
                    img=img,
                    img_ids=image_pos_ids,
                    img_mask=img_mask,
                    txt=pos_embeds,
                    txt_ids=pos_text_ids,
                    txt_mask=pos_mask,
                    timesteps=t_vec,
                )
                if guidance_scale > 1.0:
                    # Second (unconditional) pass for classifier-free guidance.
                    pred_neg = self.transformer(
                        img=img,
                        img_ids=image_pos_ids,
                        img_mask=img_mask,
                        txt=neg_embeds,
                        txt_ids=neg_text_ids,
                        txt_mask=neg_mask,
                        timesteps=t_vec,
                    )
                    pred = pred_neg + guidance_scale * (pred - pred_neg)
                # Euler step along the prediction; t_prev < t_curr.
                img = img + (t_prev - t_curr) * pred
                progress_bar.update()

        if output_type == "latent":
            image = img
        else:
            # --- Unpatchify: [B, num_patches, C*P*P] -> [B, 3, H, W] ---
            pixel_shape = (batch_size, 3, height, width)
            image = vae_unflatten(img.float(), pixel_shape, patch_size=patch_size)
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return ZImagePipelineOutput(images=image)

View File

@@ -0,0 +1,739 @@
# orig code provided by lodestones, altered for ai-toolkit
from dataclasses import dataclass, field
from functools import lru_cache
from typing import List, Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
import torch.utils.checkpoint as ckpt
@dataclass
class ZImageDCTParams:
    """Configuration for :class:`ZImageDCT`."""

    # Logical token patch size inside the transformer (pixel patching to 32x32
    # is done by the caller; see ZetaChromaModel.patch_size).
    patch_size: int = 1
    f_patch_size: int = 1
    # Flattened per-token input channels (32*32*3 = 3072 for pixel patches).
    in_channels: int = 128
    dim: int = 3840
    n_layers: int = 30
    n_refiner_layers: int = 2
    n_heads: int = 30
    n_kv_heads: int = 30
    norm_eps: float = 1e-5
    qk_norm: bool = True
    # Text-encoder hidden size (Qwen3 second-to-last hidden state width).
    cap_feat_dim: int = 2560
    rope_theta: int = 256
    # Timestep scale fed to the embedder: (1 - t) * t_scale.
    t_scale: float = 1000.0
    # Per-axis RoPE dims/lengths (sequence offset, row, column);
    # axes_dims must sum to dim // n_heads.
    axes_dims: list = field(default_factory=lambda: [32, 48, 48])
    axes_lens: list = field(default_factory=lambda: [1536, 512, 512])
    adaln_embed_dim: int = 256
    # x0-parameterized checkpoints (marker buffer "__x0__") have their output
    # converted to a velocity-style residual in ZImageDCT.forward.
    use_x0: bool = True
    # DCT decoder params
    decoder_hidden_size: int = 3840
    decoder_num_res_blocks: int = 4
    decoder_max_freqs: int = 8
class FakeConfig:
    """Minimal stand-in config so diffusers-style code can read
    ``model.config.patch_size`` from :class:`ZImageDCT`."""

    def __init__(self) -> None:
        # The transformer always exposes a logical patch size of 1.
        self.patch_size = 1
def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
if attn_mask is None:
return None
if attn_mask.ndim == 2:
attn_mask = attn_mask[:, None, None, :]
if attn_mask.dtype == torch.bool:
new_mask = torch.zeros_like(attn_mask, dtype=dtype)
new_mask.masked_fill_(~attn_mask, float("-inf"))
return new_mask
return attn_mask
def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run scaled-dot-product attention on ``[B, S, H, D]`` tensors.

    SDPA expects ``[B, H, S, D]``, so heads are moved to dim 1 on the way in
    and back to dim 2 on the way out; the mask is normalized first.
    """
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    sdpa_mask = _process_mask(attn_mask, q.dtype)
    attn = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=sdpa_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )
    return attn.transpose(1, 2).contiguous()
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a SiLU MLP."""

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        super().__init__()
        hidden = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden, bias=True),
            nn.SiLU(),
            nn.Linear(hidden, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Classic sinusoidal embedding of ``t`` -> ``[B, dim]``, in fp32."""
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period)
                * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
                / half
            )
            angles = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
            if dim % 2:
                # Odd dim: pad one zero column so the width matches.
                pad = torch.zeros_like(embedding[:, :1])
                embedding = torch.cat([embedding, pad], dim=-1)
            return embedding

    def forward(self, t):
        feats = self.timestep_embedding(t, self.frequency_embedding_size)
        target = self.mlp[0].weight.dtype
        # Quantized layers can report a non-float weight dtype; skip the cast then.
        if target.is_floating_point:
            feats = feats.to(target)
        return self.mlp(feats)
class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean subtraction) with learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * inv_rms) * self.weight
class FeedForward(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1 x) * w3 x)``, all projections bias-free."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate features by complex RoPE frequencies.

    ``x_in`` is ``[B, S, H, D]`` with even D; channel pairs are viewed as
    complex numbers and multiplied by ``freqs_cis`` (complex ``[B, S, D/2]``,
    broadcast over the head axis). Math runs in fp32 with autocast disabled,
    and the result is cast back to ``x_in``'s dtype.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
        out = torch.view_as_real(rotated).flatten(3)
        return out.type_as(x_in)
class ZImageAttention(nn.Module):
    """Bias-free multi-head self-attention with optional QK RMSNorm and RoPE.

    ``to_out`` is a one-element ModuleList — presumably to keep checkpoint key
    names (``to_out.0``) aligned with diffusers attention modules; confirm.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.to_k = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.to_v = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.to_out = nn.ModuleList(
            [nn.Linear(n_heads * self.head_dim, dim, bias=False)]
        )
        # Per-head RMSNorm applied to queries/keys before RoPE.
        self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None
        self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attend over ``hidden_states`` ``[B, S, dim]``; returns the same shape."""
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)
        # [B, S, H*D] -> [B, S, H, D]
        query = query.unflatten(-1, (self.n_heads, -1))
        key = key.unflatten(-1, (self.n_kv_heads, -1))
        value = value.unflatten(-1, (self.n_kv_heads, -1))
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)
        if freqs_cis is not None:
            # Rotary embedding; helper computes in fp32 and casts back.
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)
        hidden_states = _native_attention_wrapper(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # [B, S, H, D] -> [B, S, H*D]
        hidden_states = hidden_states.flatten(2, 3).to(dtype)
        return self.to_out[0](hidden_states)
class ZImageTransformerBlock(nn.Module):
    """Sandwich-normed transformer block with optional tanh-gated AdaLN.

    With ``modulation=True`` the timestep embedding yields per-sample
    (scale, gate) pairs for both the attention and MLP branches: gates go
    through ``tanh`` and scales are applied around 1. With
    ``modulation=False`` it is a plain pre/post-normed block (used by the
    caption/context refiner).
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        adaln_embed_dim: int = 256,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation
        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        # SwiGLU hidden width: dim * 8/3.
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        # Post-branch norms (sandwich normalization).
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
        if modulation:
            self.adaLN_modulation = nn.ModuleList(
                [nn.Linear(min(dim, adaln_embed_dim), 4 * dim, bias=True)]
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply attention and MLP with residuals; ``adaln_input`` is required iff modulated."""
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = (
                self.adaLN_modulation[0](adaln_input).unsqueeze(1).chunk(4, dim=2)
            )
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa,
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + gate_msa * self.attention_norm2(attn_out)
            x = x + gate_mlp * self.ffn_norm2(
                self.feed_forward(self.ffn_norm1(x) * scale_mlp)
            )
        else:
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
        return x
class RopeEmbedder:
    """Lazily precomputed multi-axis rotary (RoPE) frequency table.

    Not an nn.Module: the complex tables are plain tensors, built on CPU on
    first call and then cached on the device of the incoming position ids.
    """

    def __init__(
        self,
        theta: float = 256,
        axes_dims: List[int] = None,
        axes_lens: List[int] = None,
    ):
        self.theta = theta
        self.axes_dims = axes_dims or [32, 48, 48]
        self.axes_lens = axes_lens or [1536, 512, 512]
        assert len(self.axes_dims) == len(self.axes_lens)
        # Built lazily in __call__.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256):
        """Return one complex64 table per axis, shaped ``[end_i, dim_i // 2]``."""
        with torch.device("cpu"):
            freqs_cis = []
            for d, e in zip(dim, end):
                # Standard RoPE inverse-frequency spectrum, computed in fp64.
                freqs = 1.0 / (
                    theta
                    ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)
                )
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                freqs = torch.outer(timestep, freqs).float()
                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(
                    torch.complex64
                )
                freqs_cis.append(freqs_cis_i)
            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        """Gather frequencies for integer ids ``[..., num_axes]`` and concat per axis."""
        assert ids.ndim >= 2 and ids.shape[-1] == len(self.axes_dims)
        device = ids.device
        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(
                self.axes_dims, self.axes_lens, theta=self.theta
            )
            self.freqs_cis = [f.to(device) for f in self.freqs_cis]
        elif self.freqs_cis[0].device != device:
            # Ids moved device (e.g. CPU -> GPU): migrate the cached tables.
            self.freqs_cis = [f.to(device) for f in self.freqs_cis]
        return torch.cat(
            [self.freqs_cis[i][ids[..., i]] for i in range(len(self.axes_dims))], dim=-1
        )
# --- Decoder components ---
def modulate(x, shift, scale):
    """AdaLN modulation: scale around identity, then shift."""
    return x * (scale + 1) + shift
class NerfEmbedder(nn.Module):
    """Embed per-pixel patch values together with a 2-D DCT positional basis.

    Input is ``[B, P*P, C]`` (one row per pixel of a patch); the DCT basis for
    the patch grid is concatenated channel-wise and projected to
    ``hidden_size_input``. The linear projection runs in fp32 (autocast off)
    and the result is cast back to the input dtype.
    """

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input)
        )

    # Fix for ruff B019: the previous @lru_cache on the bound method keyed the
    # cache on `self`, pinning module instances alive for the cache lifetime.
    # The cache now lives on a staticmethod keyed only on hashable args.
    # (staticmethod-wrapped lru_cache requires Python >= 3.10.)
    @staticmethod
    @lru_cache(maxsize=4)
    def _cached_pos(patch_size, max_freqs, device, dtype):
        """Build the per-patch DCT basis ``[1, P*P, max_freqs**2]``."""
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)
        freqs = torch.linspace(
            0, max_freqs - 1, max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Down-weight high-frequency products.
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        return (dct_x * dct_y * coeffs).view(1, -1, max_freqs**2)

    def fetch_pos(self, patch_size, device, dtype):
        """Return the cached DCT basis for this module's ``max_freqs``."""
        return self._cached_pos(patch_size, self.max_freqs, device, dtype)

    def forward(self, inputs):
        B, P2, C = inputs.shape
        original_dtype = inputs.dtype
        with torch.autocast("cuda", enabled=False):
            # Patch is square: P2 = patch_size ** 2.
            patch_size = int(P2**0.5)
            inputs = inputs.float()
            dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
            dct = dct.repeat(B, 1, 1)
            inputs = torch.cat([inputs, dct], dim=-1)
            # NOTE(review): .float() converts the stored weights to fp32
            # in place on every call — presumably intentional; confirm.
            inputs = self.embedder.float()(inputs)
        return inputs.to(original_dtype)
class ResBlock(nn.Module):
    """AdaLN-modulated MLP residual block.

    ``y`` supplies per-token shift/scale/gate. The adaLN projection is
    zero-initialized, so the gate starts at zero and the block is exactly the
    identity at initialization.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )
        self._init_weights()

    def _init_weights(self):
        # Kaiming for the MLP, zeros for the modulation head (identity start).
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="linear")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

    def forward(self, x, y):
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * hidden
class DCTFinalLayer(nn.Module):
    """Zero-initialized projection head: affine-free LayerNorm, then Linear.

    Zero init means the decoder's output is exactly zero at the start of
    training.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6
        )
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        normed = self.norm_final(x)
        return self.linear(normed)
class SimpleMLPAdaLN(nn.Module):
    """Per-patch pixel decoder: NeRF-style embed, AdaLN ResBlocks, zero-init head.

    ``x`` holds noisy pixel patches ``[B*num_patches, P*P, in_channels]``;
    ``c`` holds the transformer token state ``[B*num_patches, z_channels]``,
    which ``cond_embed`` expands into per-pixel conditioning vectors.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        patch_size,
        max_freqs=8,
    ):
        super().__init__()
        self.patch_size = patch_size
        # One model_channels-wide conditioning vector per pixel of the patch.
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_embedder = NerfEmbedder(
            in_channels=in_channels,
            hidden_size_input=model_channels,
            max_freqs=max_freqs,
        )
        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = DCTFinalLayer(model_channels, out_channels)
        nn.init.xavier_uniform_(self.cond_embed.weight)
        nn.init.constant_(self.cond_embed.bias, 0)

    def forward(self, x, c):
        """Decode conditioned pixel patches; returns ``[B*num_patches, P*P, out_channels]``."""
        x = self.input_embedder(x)
        c = self.cond_embed(c)
        # [B*, P*P * model_channels] -> [B*, P*P, model_channels]
        y = c.reshape(c.shape[0], self.patch_size**2, -1)
        for block in self.res_blocks:
            x = block(x, y)
        return self.final_layer(x)
class ZImageDCT(nn.Module):
    """Z-Image-style DiT backbone with a per-token pixel (DCT) decoder.

    Image tokens (flattened pixel patches) pass through a shallow modulated
    noise refiner, caption tokens through a plain context refiner; both are
    concatenated (text first) and run through the main modulated stack.
    Image tokens are then decoded back to pixels per patch by ``dec_net``.
    Checkpoints trained with x0 parameterization register an empty
    ``__x0__`` marker buffer, which switches ``forward`` to return a
    velocity-style residual.
    """

    def __init__(self, params: ZImageDCTParams):
        super().__init__()
        # Minimal diffusers-style .config shim (exposes patch_size only).
        self.config = FakeConfig()
        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        self.patch_size = params.patch_size
        self.f_patch_size = params.f_patch_size
        self.dim = params.dim
        self.n_heads = params.n_heads
        self.rope_theta = params.rope_theta
        self.t_scale = params.t_scale
        self.adaln_embed_dim = params.adaln_embed_dim
        self.x_embedder = nn.Linear(
            self.f_patch_size * self.patch_size * self.patch_size * params.in_channels,
            params.dim,
            bias=True,
        )
        # Shallow image-token refiner; layer ids offset by 1000 to distinguish
        # them from the main stack — presumably cosmetic; confirm.
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=True,
                    adaln_embed_dim=params.adaln_embed_dim,
                )
                for i in range(params.n_refiner_layers)
            ]
        )
        # Caption refiner: unmodulated (no timestep conditioning).
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=False,
                )
                for i in range(params.n_refiner_layers)
            ]
        )
        self.t_embedder = TimestepEmbedder(
            min(params.dim, params.adaln_embed_dim), mid_size=1024
        )
        # Caption projection: RMSNorm then linear to the model width.
        self.cap_embedder = nn.Sequential(
            RMSNorm(params.cap_feat_dim, eps=params.norm_eps),
            nn.Linear(params.cap_feat_dim, params.dim, bias=True),
        )
        # Learned pad tokens; values come from the checkpoint (empty init).
        self.x_pad_token = nn.Parameter(torch.empty((1, params.dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, params.dim)))
        # Main joint (text + image) transformer stack.
        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=True,
                    adaln_embed_dim=params.adaln_embed_dim,
                )
                for i in range(params.n_layers)
            ]
        )
        head_dim = params.dim // params.n_heads
        # RoPE axis dims must tile the full head dimension.
        assert head_dim == sum(params.axes_dims)
        self.axes_dims = params.axes_dims
        self.axes_lens = params.axes_lens
        self.rope_embedder = RopeEmbedder(
            theta=params.rope_theta,
            axes_dims=params.axes_dims,
            axes_lens=params.axes_lens,
        )
        # Per-token pixel decoder conditioned on the final hidden state.
        self.dec_net = SimpleMLPAdaLN(
            in_channels=params.in_channels,
            model_channels=params.decoder_hidden_size,
            out_channels=params.in_channels,
            z_channels=params.dim,
            num_res_blocks=params.decoder_num_res_blocks,
            patch_size=self.patch_size,
            max_freqs=params.decoder_max_freqs,
        )
        if params.use_x0:
            # Empty marker buffer: its presence in the state dict flags x0
            # parameterization (see load_model's auto-detection and forward()).
            self.register_buffer("__x0__", torch.tensor([]))
        self.gradient_checkpointing = False

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        # Checkpoint the main stack during training to save memory.
        self.gradient_checkpointing = True

    def _forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        img_mask: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
    ):
        """Core pass; returns ``[B, num_patches, in_channels]``."""
        B = img.shape[0]
        num_patches = img.shape[1]
        # Raw pixel patches kept aside for the per-patch decoder.
        pixel_values = img.reshape(
            B * num_patches, self.patch_size**2, self.in_channels
        )
        # Schedule runs 1 (noise) -> 0 (clean); the embedder sees (1-t)*1000.
        timesteps = (1 - timesteps) * self.t_scale
        timesteps_embedding = self.t_embedder(timesteps)
        img_hidden = self.x_embedder(img)
        txt_hidden = self.cap_embedder(txt)
        img_pe = self.rope_embedder(img_ids)
        txt_pe = self.rope_embedder(txt_ids)
        # Separate shallow refiners: image (modulated) and text (plain).
        for layer in self.noise_refiner:
            img_hidden = layer(img_hidden, img_mask, img_pe, timesteps_embedding)
        for layer in self.context_refiner:
            txt_hidden = layer(txt_hidden, txt_mask, txt_pe)
        # Joint stream: text tokens first, then image tokens.
        mixed_hidden = torch.cat((txt_hidden, img_hidden), 1)
        mixed_mask = torch.cat((txt_mask, img_mask), 1)
        mixed_pe = torch.cat((txt_pe, img_pe), 1)
        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                mixed_hidden = ckpt.checkpoint(
                    layer,
                    mixed_hidden,
                    mixed_mask,
                    mixed_pe,
                    timesteps_embedding,
                    use_reentrant=False,
                )
            else:
                mixed_hidden = layer(
                    mixed_hidden, mixed_mask, mixed_pe, timesteps_embedding
                )
        # Drop text tokens; decode each image token back to pixels.
        img_hidden = mixed_hidden[:, txt.shape[1] :, ...]
        decoder_condition = img_hidden.reshape(B * num_patches, self.dim)
        output = self.dec_net(pixel_values, decoder_condition)
        output = output.reshape(B, num_patches, -1)
        # NOTE(review): sign flip — presumably matches the x0-residual
        # convention applied in forward(); confirm against training code.
        return -output

    def _apply_x0_residual(self, predicted, noisy, timesteps):
        # Convert an x0 prediction into a velocity-style output: (x_t - x0)/t.
        return (noisy - predicted) / timesteps.view(-1, 1, 1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        img_mask: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
    ):
        """Predict for noisy patch tokens ``img`` at ``timesteps`` (x0-aware)."""
        out = self._forward(
            img=img,
            img_ids=img_ids,
            img_mask=img_mask,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=timesteps,
        )
        # x0 checkpoints register the "__x0__" marker buffer (see __init__).
        if hasattr(self, "__x0__"):
            return self._apply_x0_residual(out, img, timesteps)
        return out
def vae_flatten(latents, patch_size=2):
    """Patchify ``[N, C, H, W]`` into ``[N, num_patches, patch_size*patch_size*C]``.

    Returns the flattened tensor together with the original shape so
    :func:`vae_unflatten` can invert the operation. Per-patch channel order is
    (dh, dw, c), matching einops ``n c (h dh) (w dw) -> n (h w) (dh dw c)``.
    """
    n, c, h, w = latents.shape
    grid_h, grid_w = h // patch_size, w // patch_size
    patches = latents.reshape(n, c, grid_h, patch_size, grid_w, patch_size)
    patches = patches.permute(0, 2, 4, 3, 5, 1)
    flat = patches.reshape(n, grid_h * grid_w, patch_size * patch_size * c)
    return flat, latents.shape
def vae_unflatten(latents, shape, patch_size=2):
    """Invert :func:`vae_flatten`: ``[N, num_patches, p*p*C]`` back to ``[N, C, H, W]``.

    ``shape`` is the original ``(n, c, h, w)``; per-patch channels are assumed
    to be in (dh, dw, c) order.
    """
    n, c, h, w = shape
    grid_h, grid_w = h // patch_size, w // patch_size
    grid = latents.reshape(n, grid_h, grid_w, patch_size, patch_size, c)
    grid = grid.permute(0, 5, 1, 3, 2, 4)
    return grid.reshape(n, c, grid_h * patch_size, grid_w * patch_size)
def prepare_latent_image_ids(start_indices, height, width, patch_size=2, max_offset=0):
    """Build 3-D RoPE position ids for an image patch grid.

    Axis 0 carries the per-sample sequence offset (image tokens follow the
    text tokens), axes 1/2 carry row/column indices. Returns an int tensor of
    shape ``[B, (height//patch_size) * (width//patch_size), 3]``.
    ``max_offset`` is currently unused; kept for API compatibility.
    """
    if isinstance(start_indices, list):
        start_indices = torch.tensor(start_indices)
    batch_size = len(start_indices)
    rows = height // patch_size
    cols = width // patch_size
    grid_ids = torch.zeros(rows, cols, 3)
    grid_ids[..., 1] = grid_ids[..., 1] + torch.arange(rows)[:, None]
    grid_ids[..., 2] = grid_ids[..., 2] + torch.arange(cols)[None, :]
    batch_ids = grid_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    for sample, offset in enumerate(start_indices):
        batch_ids[sample, :, :, 0] = offset
    return batch_ids.reshape(batch_size, rows * cols, 3).int()
def make_text_position_ids(valid_len, max_sequence_length, extra_padding=0):
    """Build 3-D RoPE ids for text tokens.

    Axis 0 counts 1..valid_len and then repeats the final value across the
    padding slots; axes 1 and 2 stay zero. ``valid_len`` is a ``[B]`` tensor
    of per-sample token counts.
    """
    device = valid_len.device
    lengths = valid_len + extra_padding
    batch = lengths.shape[0]
    positions = torch.arange(1, max_sequence_length + 1, device=device)
    positions = positions.unsqueeze(0).expand(batch, -1)
    # Count up to the valid length, then plateau for padding.
    clamped = torch.minimum(positions, lengths.unsqueeze(1))
    ids = torch.zeros((batch, max_sequence_length, 3), device=device)
    ids[:, :, 0] = clamped
    return ids.int()
def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    """Warp timesteps in logit space: identity at mu=0, sigma=1."""
    pivot = math.exp(mu)
    return pivot / (pivot + (1 / t - 1) ** sigma)
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list:
    """Build a timestep schedule from t=1 (pure noise) down to t=0 (clean).

    When ``shift`` is enabled, the uniform schedule is warped by
    :func:`time_shift` with a mu interpolated linearly in the image sequence
    length between 256 and 4096 tokens (longer sequences shift harder).
    Returns ``num_steps + 1`` floats.
    """
    steps = torch.linspace(1, 0, num_steps + 1)
    if shift:
        slope = (max_shift - base_shift) / (4096 - 256)
        intercept = base_shift - slope * 256
        mu = slope * image_seq_len + intercept
        steps = time_shift(mu, 1.0, steps)
    return steps.tolist()

View File

@@ -29,7 +29,8 @@ type AdditionalSections =
| 'model.low_vram'
| 'model.qie.match_target_res'
| 'model.assistant_lora_path';
type ModelGroup = 'image' | 'instruction' | 'video';
type ModelGroup = 'image' | 'instruction' | 'video' | 'experimental';
export interface ModelArch {
name: string;
@@ -133,6 +134,21 @@ export const modelArchs: ModelArch[] = [
},
disableSections: ['network.conv'],
},
{
name: 'zeta_chroma',
label: 'Zeta Chroma',
group: 'experimental',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors', defaultNameOrPath],
'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'wan21:1b',
label: 'Wan 2.1 (1.3B)',

View File

@@ -1 +1 @@
VERSION = "0.7.23"
VERSION = "0.7.24"