diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 6d20efa5..88d1bed3 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -8,6 +8,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusMod from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel from .z_image import ZImageModel from .ltx2 import LTX2Model +from .zeta_chroma import ZetaChromaModel AI_TOOLKIT_MODELS = [ # put a list of models here @@ -29,4 +30,5 @@ AI_TOOLKIT_MODELS = [ LTX2Model, Flux2Klein4BModel, Flux2Klein9BModel, + ZetaChromaModel, ] diff --git a/extensions_built_in/diffusion_models/zeta_chroma/__init__.py b/extensions_built_in/diffusion_models/zeta_chroma/__init__.py new file mode 100644 index 00000000..b9067709 --- /dev/null +++ b/extensions_built_in/diffusion_models/zeta_chroma/__init__.py @@ -0,0 +1 @@ +from .zeta_chroma_model import ZetaChromaModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py new file mode 100644 index 00000000..bddbf09d --- /dev/null +++ b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py @@ -0,0 +1,380 @@ +import os +from typing import List, Optional + +import huggingface_hub +import torch +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file +from optimum.quanto import QTensor +from toolkit.metadata import get_meta_for_safetensors +from safetensors.torch import load_file, save_file +from transformers import AutoTokenizer, Qwen3ForCausalLM +from diffusers import AutoencoderKL +from toolkit.models.FakeVAE import FakeVAE +from .zeta_chroma_transformer import ZImageDCT, ZImageDCTParams, vae_flatten, vae_unflatten, prepare_latent_image_ids, make_text_position_ids +from .zeta_chroma_pipeline import ZetaChromaPipeline + + + +scheduler_config = { + "num_train_timesteps": 1000, + "use_dynamic_shifting": False, + "shift": 3.0, +} + +ZETA_CHROMA_TRANSFORMER_FILENAME = "zeta-chroma-base-x0-pixel-dino-distance.safetensors" + + +class ZetaChromaModel(BaseModel): + arch = "zeta_chroma" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["ZImageDCT"] + self.patch_size = 32 + self.max_sequence_length = 512 + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return self.patch_size + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading ZImage model") + model_path = 
self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + self.print_and_status_update("Loading transformer") + + + transformer_path = model_path + if not os.path.exists(transformer_path): + transformer_name = ZETA_CHROMA_TRANSFORMER_FILENAME + # if path ends with .safetensors, assume last part is filename + # this allows users to target different file names in the repo like + # lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors + + if transformer_path.endswith(".safetensors"): + splits = transformer_path.split("/") + transformer_name = splits[-1] + transformer_path = "/".join(splits[:-1]) + # assume it is from the hub + transformer_path = huggingface_hub.hf_hub_download( + repo_id=transformer_path, + filename=transformer_name, + ) + + transformer_state_dict = load_file(transformer_path, device="cpu") + + # cast to dtype + for key in transformer_state_dict: + transformer_state_dict[key] = transformer_state_dict[key].to(dtype) + + # Auto-detect use_x0 from checkpoint + use_x0 = "__x0__" in transformer_state_dict + + # Build model params + in_channels = self.patch_size * self.patch_size * 3 # RGB patches + model_params = ZImageDCTParams( + patch_size=1, + in_channels=in_channels, + use_x0=use_x0, + ) + + with torch.device("meta"): + transformer = ZImageDCT(model_params) + + transformer.load_state_dict(transformer_state_dict, assign=True) + del transformer_state_dict + + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + transformer.x_pad_token, + transformer.cap_pad_token, + ], + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Text Encoder") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype + ) + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAE") + vae = FakeVAE(scaling_factor=1.0) + vae.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ZetaChromaModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + kwargs = {} + + pipe: ZetaChromaPipeline = ZetaChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = 
[pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ZetaChromaModel.get_train_scheduler() + + pipeline: ZetaChromaPipeline = ZetaChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ZetaChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + self.model.to(self.device_torch) + + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + with torch.no_grad(): + + pixel_shape = latent_model_input.shape + # todo: do we invert like this? 
+ # t_vec = (1000 - timestep) / 1000 + t_vec = timestep / 1000 + + height = latent_model_input.shape[2] + h_patches = height // self.patch_size + width = latent_model_input.shape[3] + w_patches = width // self.patch_size + batch_size = latent_model_input.shape[0] + + img, _ = vae_flatten(latent_model_input, patch_size=self.patch_size) + + num_patches = img.shape[1] + + # --- Build position IDs --- + pos_lengths = text_embeddings.attention_mask.sum(1) + offset = pos_lengths + + image_pos_ids = prepare_latent_image_ids( + offset, h_patches, w_patches, patch_size=1 + ).to(self.device_torch) + pos_text_ids = make_text_position_ids(pos_lengths, self.max_sequence_length).to( + self.device_torch + ) + img_mask = torch.ones( + (batch_size, num_patches), device=self.device_torch, dtype=torch.bool + ) + + + + # model_out_list = self.transformer( + # latent_model_input_list, + # t_vec, + # text_embeddings.text_embeds, + # )[0] + pred = self.transformer( + img=img, #(1, 1024, 3072) + img_ids=image_pos_ids, # (1, 1024, 3) + img_mask=img_mask, # (1, 1024) + txt=text_embeddings.text_embeds, # (1, 512, 2560) + txt_ids=pos_text_ids, # (1, 512, 3) + txt_mask=text_embeddings.attention_mask, # (1, 512) + timesteps=t_vec, # (1,) + ) + + pred = vae_unflatten(pred.float(), pixel_shape, patch_size=self.patch_size) + + return pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + prompt_embeds, mask = self.pipeline._encode_prompts( + prompt, + ) + pe = PromptEmbeds([prompt_embeds, None], attention_mask=mask) + + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" + # only save the unet + transformer: ZImageDCT = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to("cpu", dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name="zeta_chroma") + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return "zeta_chroma" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["layers"] + + def convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd diff --git a/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_pipeline.py b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_pipeline.py new file mode 100644 index 00000000..2df2e6d6 --- /dev/null +++ b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_pipeline.py @@ -0,0 +1,180 @@ +from diffusers.pipelines.z_image.pipeline_z_image import ( + ZImagePipeline, + calculate_shift, + retrieve_timesteps, +) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from diffusers.utils.torch_utils 
import randn_tensor +import torch +from diffusers.utils import logging, replace_example_docstring +from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput +from extensions_built_in.diffusion_models.zeta_chroma.zeta_chroma_transformer import ( + get_schedule, + prepare_latent_image_ids, + make_text_position_ids, + vae_unflatten, +) + + +class ZetaChromaPipeline(ZImagePipeline): + need_something_here = True + patch_size = 32 + max_sequence_length = 512 + + @torch.no_grad() + def _encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode a list of prompts with the Qwen3 chat template.""" + formatted = [] + for p in prompts: + messages = [{"role": "user", "content": p}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + formatted.append(text) + + inputs = self.tokenizer( + formatted, + padding="max_length", + max_length=self.max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(self.text_encoder.device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = self.text_encoder( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + output_hidden_states=True, + ) + + # Second-to-last hidden state (same as training) + embeddings = outputs.hidden_states[-2] + mask = inputs.attention_mask.bool() + return embeddings, mask + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + prompt_embeds_mask: Optional[torch.BoolTensor] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds_mask: Optional[torch.BoolTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + device = self._execution_device + + batch_size = len(prompt_embeds) + device = self._execution_device + patch_size = self.patch_size + in_channels = patch_size * patch_size * 3 + + h_patches = height // patch_size + w_patches = width // patch_size + num_patches = h_patches * w_patches + + pos_embeds, pos_mask = prompt_embeds, prompt_embeds_mask + neg_embeds, neg_mask = negative_prompt_embeds, negative_prompt_embeds_mask + + # --- Build position IDs --- + pos_lengths = pos_mask.sum(1) + neg_lengths = neg_mask.sum(1) + offset = torch.maximum(pos_lengths, neg_lengths) + + image_pos_ids = prepare_latent_image_ids( + offset, h_patches, w_patches, patch_size=1 + ).to(device) + pos_text_ids = make_text_position_ids(pos_lengths, max_sequence_length).to( + device + ) + neg_text_ids = make_text_position_ids(neg_lengths, max_sequence_length).to( + device + ) + + # --- Initial noise --- + noise = randn_tensor( + (batch_size, num_patches, in_channels), + generator=generator, + device=device, + dtype=self.transformer.dtype, + ) 
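+        # NOTE: this model works directly in pixel space (the VAE handed to the
+        # pipeline is a FakeVAE), so the initial noise is a grid of flattened RGB
+        # patches with in_channels = patch_size * patch_size * 3 values per patch.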
+ + # --- Timestep schedule --- + timesteps = get_schedule(num_inference_steps, num_patches) + + # --- Denoising loop (CFG) --- + img = noise + img_mask = torch.ones( + (batch_size, num_patches), device=device, dtype=torch.bool + ) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + + t_vec = torch.full( + (batch_size,), t_curr, dtype=self.dtype, device=device + ) + + pred = self.transformer( + img=img, + img_ids=image_pos_ids, + img_mask=img_mask, + txt=pos_embeds, + txt_ids=pos_text_ids, + txt_mask=pos_mask, + timesteps=t_vec, + ) + + if guidance_scale > 1.0: + pred_neg = self.transformer( + img=img, + img_ids=image_pos_ids, + img_mask=img_mask, + txt=neg_embeds, + txt_ids=neg_text_ids, + txt_mask=neg_mask, + timesteps=t_vec, + ) + pred = pred_neg + guidance_scale * (pred - pred_neg) + + img = img + (t_prev - t_curr) * pred + + progress_bar.update() + + if output_type == "latent": + image = img + + else: + # --- Unpatchify: [B, num_patches, C*P*P] -> [B, 3, H, W] --- + pixel_shape = (batch_size, 3, height, width) + image = vae_unflatten(img.float(), pixel_shape, patch_size=patch_size) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_transformer.py b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_transformer.py new file mode 100644 index 00000000..450d8664 --- /dev/null +++ b/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_transformer.py @@ -0,0 +1,739 @@ +# orig code provided by lodestones, altered for ai-toolkit + +from dataclasses import dataclass, field +from functools import lru_cache +from typing import List, Optional +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from einops import rearrange +import torch.utils.checkpoint as ckpt + + +@dataclass +class ZImageDCTParams: + patch_size: int = 1 + f_patch_size: int = 1 + in_channels: int = 128 + dim: int = 3840 + n_layers: int = 30 + n_refiner_layers: int = 2 + n_heads: int = 30 + n_kv_heads: int = 30 + norm_eps: float = 1e-5 + qk_norm: bool = True + cap_feat_dim: int = 2560 + rope_theta: int = 256 + t_scale: float = 1000.0 + axes_dims: list = field(default_factory=lambda: [32, 48, 48]) + axes_lens: list = field(default_factory=lambda: [1536, 512, 512]) + adaln_embed_dim: int = 256 + use_x0: bool = True + # DCT decoder params + decoder_hidden_size: int = 3840 + decoder_num_res_blocks: int = 4 + decoder_max_freqs: int = 8 + + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.patch_size = 1 + + +def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype): + if attn_mask is None: + return None + if attn_mask.ndim == 2: + attn_mask = attn_mask[:, None, None, :] + if attn_mask.dtype == torch.bool: + new_mask = torch.zeros_like(attn_mask, dtype=dtype) + new_mask.masked_fill_(~attn_mask, float("-inf")) + return new_mask + return attn_mask + + +def _native_attention_wrapper( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> torch.Tensor: + query = query.transpose(1, 
2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attn_mask = _process_mask(attn_mask, query.dtype) + out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + return out.transpose(1, 2).contiguous() + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + return self.mlp(t_freq) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return output * self.weight + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + +class ZImageAttention(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + qk_norm: bool = True, + eps: float = 1e-5, + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = dim // n_heads + + self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False) + self.to_k = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) + self.to_v = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) + self.to_out = nn.ModuleList( + [nn.Linear(n_heads * self.head_dim, dim, bias=False)] + ) + + self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None + self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = query.unflatten(-1, (self.n_heads, -1)) + key = key.unflatten(-1, (self.n_kv_heads, -1)) + value = value.unflatten(-1, 
(self.n_kv_heads, -1)) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + hidden_states = _native_attention_wrapper( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.flatten(2, 3).to(dtype) + return self.to_out[0](hidden_states) + + +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation: bool = True, + adaln_embed_dim: int = 256, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.layer_id = layer_id + self.modulation = modulation + + self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps) + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + if modulation: + self.adaLN_modulation = nn.ModuleList( + [nn.Linear(min(dim, adaln_embed_dim), 4 * dim, bias=True)] + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + self.adaLN_modulation[0](adaln_input).unsqueeze(1).chunk(4, dim=2) + ) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + attention_mask=attn_mask, + freqs_cis=freqs_cis, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + x = x + gate_mlp * self.ffn_norm2( + self.feed_forward(self.ffn_norm1(x) * scale_mlp) + ) + else: + attn_out = self.attention( + self.attention_norm1(x), + attention_mask=attn_mask, + freqs_cis=freqs_cis, + ) + x = x + self.attention_norm2(attn_out) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256, + axes_dims: List[int] = None, + axes_lens: List[int] = None, + ): + self.theta = theta + self.axes_dims = axes_dims or [32, 48, 48] + self.axes_lens = axes_lens or [1536, 512, 512] + assert len(self.axes_dims) == len(self.axes_lens) + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256): + with torch.device("cpu"): + freqs_cis = [] + for d, e in zip(dim, end): + freqs = 1.0 / ( + theta + ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d) + ) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to( + torch.complex64 + ) + freqs_cis.append(freqs_cis_i) + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim >= 2 and ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis( + self.axes_dims, self.axes_lens, theta=self.theta + ) + self.freqs_cis = [f.to(device) for f in self.freqs_cis] + elif self.freqs_cis[0].device != device: + 
self.freqs_cis = [f.to(device) for f in self.freqs_cis] + + return torch.cat( + [self.freqs_cis[i][ids[..., i]] for i in range(len(self.axes_dims))], dim=-1 + ) + + +# --- Decoder components --- + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels, hidden_size_input, max_freqs): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels + max_freqs**2, hidden_size_input) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size, device, dtype): + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + freqs = torch.linspace( + 0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device + ) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + coeffs = (1 + freqs_x * freqs_y) ** -1 + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2) + return dct + + def forward(self, inputs): + B, P2, C = inputs.shape + original_dtype = inputs.dtype + with torch.autocast("cuda", enabled=False): + patch_size = int(P2**0.5) + inputs = inputs.float() + dct = self.fetch_pos(patch_size, inputs.device, torch.float32) + dct = dct.repeat(B, 1, 1) + inputs = torch.cat([inputs, dct], dim=-1) + inputs = self.embedder.float()(inputs) + return inputs.to(original_dtype) + + +class ResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 3 * channels, bias=True), + ) + self._init_weights() + + def _init_weights(self): + for m in self.mlp: + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="linear") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class DCTFinalLayer(nn.Module): + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm( + model_channels, elementwise_affine=False, eps=1e-6 + ) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + nn.init.constant_(self.linear.weight, 0) + nn.init.constant_(self.linear.bias, 0) + + def forward(self, x): + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + patch_size, + max_freqs=8, + ): + super().__init__() + self.patch_size = patch_size + self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels) + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + ) + self.res_blocks = nn.ModuleList( + [ResBlock(model_channels) for _ in 
range(num_res_blocks)] + ) + self.final_layer = DCTFinalLayer(model_channels, out_channels) + nn.init.xavier_uniform_(self.cond_embed.weight) + nn.init.constant_(self.cond_embed.bias, 0) + + def forward(self, x, c): + x = self.input_embedder(x) + c = self.cond_embed(c) + y = c.reshape(c.shape[0], self.patch_size**2, -1) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x) + + +class ZImageDCT(nn.Module): + def __init__(self, params: ZImageDCTParams): + super().__init__() + self.config = FakeConfig() + self.in_channels = params.in_channels + self.out_channels = params.in_channels + self.patch_size = params.patch_size + self.f_patch_size = params.f_patch_size + self.dim = params.dim + self.n_heads = params.n_heads + self.rope_theta = params.rope_theta + self.t_scale = params.t_scale + self.adaln_embed_dim = params.adaln_embed_dim + + self.x_embedder = nn.Linear( + self.f_patch_size * self.patch_size * self.patch_size * params.in_channels, + params.dim, + bias=True, + ) + + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=True, + adaln_embed_dim=params.adaln_embed_dim, + ) + for i in range(params.n_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=False, + ) + for i in range(params.n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder( + min(params.dim, params.adaln_embed_dim), mid_size=1024 + ) + + self.cap_embedder = nn.Sequential( + RMSNorm(params.cap_feat_dim, eps=params.norm_eps), + nn.Linear(params.cap_feat_dim, params.dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, params.dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, params.dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock( + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=True, + adaln_embed_dim=params.adaln_embed_dim, + ) + for i in range(params.n_layers) + ] + ) + + head_dim = params.dim // params.n_heads + assert head_dim == sum(params.axes_dims) + self.axes_dims = params.axes_dims + self.axes_lens = params.axes_lens + + self.rope_embedder = RopeEmbedder( + theta=params.rope_theta, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + ) + + self.dec_net = SimpleMLPAdaLN( + in_channels=params.in_channels, + model_channels=params.decoder_hidden_size, + out_channels=params.in_channels, + z_channels=params.dim, + num_res_blocks=params.decoder_num_res_blocks, + patch_size=self.patch_size, + max_freqs=params.decoder_max_freqs, + ) + + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward( + self, + img: Tensor, + img_ids: Tensor, + img_mask: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + ): + B = img.shape[0] + num_patches = img.shape[1] + + pixel_values = img.reshape( + B * num_patches, self.patch_size**2, self.in_channels + ) + + timesteps = (1 - timesteps) * self.t_scale + timesteps_embedding = self.t_embedder(timesteps) + + img_hidden = self.x_embedder(img) + 
txt_hidden = self.cap_embedder(txt) + + img_pe = self.rope_embedder(img_ids) + txt_pe = self.rope_embedder(txt_ids) + + for layer in self.noise_refiner: + img_hidden = layer(img_hidden, img_mask, img_pe, timesteps_embedding) + + for layer in self.context_refiner: + txt_hidden = layer(txt_hidden, txt_mask, txt_pe) + + mixed_hidden = torch.cat((txt_hidden, img_hidden), 1) + mixed_mask = torch.cat((txt_mask, img_mask), 1) + mixed_pe = torch.cat((txt_pe, img_pe), 1) + + for layer in self.layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + mixed_hidden = ckpt.checkpoint( + layer, + mixed_hidden, + mixed_mask, + mixed_pe, + timesteps_embedding, + use_reentrant=False, + ) + else: + mixed_hidden = layer( + mixed_hidden, mixed_mask, mixed_pe, timesteps_embedding + ) + + img_hidden = mixed_hidden[:, txt.shape[1] :, ...] + + decoder_condition = img_hidden.reshape(B * num_patches, self.dim) + output = self.dec_net(pixel_values, decoder_condition) + output = output.reshape(B, num_patches, -1) + + return -output + + def _apply_x0_residual(self, predicted, noisy, timesteps): + return (noisy - predicted) / timesteps.view(-1, 1, 1) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + img_mask: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + ): + out = self._forward( + img=img, + img_ids=img_ids, + img_mask=img_mask, + txt=txt, + txt_ids=txt_ids, + txt_mask=txt_mask, + timesteps=timesteps, + ) + if hasattr(self, "__x0__"): + return self._apply_x0_residual(out, img, timesteps) + return out + + +def vae_flatten(latents, patch_size=2): + """Patchify: [N, C, H, W] -> ([N, num_patches, patch_size*patch_size*C], original_shape)""" + return ( + rearrange( + latents, + "n c (h dh) (w dw) -> n (h w) (dh dw c)", + dh=patch_size, + dw=patch_size, + ), + latents.shape, + ) + + +def vae_unflatten(latents, shape, patch_size=2): + """Unpatchify: [N, num_patches, patch_size*patch_size*C] -> [N, C, H, W]""" + n, c, h, w = shape + return rearrange( + latents, + "n (h w) (dh dw c) -> n c (h dh) (w dw)", + dh=patch_size, + dw=patch_size, + c=c, + h=h // patch_size, + w=w // patch_size, + ) + + +def prepare_latent_image_ids(start_indices, height, width, patch_size=2, max_offset=0): + """Generate 3D positional IDs for image patches.""" + if isinstance(start_indices, list): + start_indices = torch.tensor(start_indices) + + batch_size = len(start_indices) + latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3) + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :] + ) + + h, w, ch = latent_image_ids.shape + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + + for i, start_idx in enumerate(start_indices): + latent_image_ids[i, :, :, 0] = start_idx + + return latent_image_ids.reshape(batch_size, h * w, ch).int() + + +def make_text_position_ids(valid_len, max_sequence_length, extra_padding=0): + """Generate 3D positional IDs for text tokens.""" + device = valid_len.device + valid_len = valid_len + extra_padding + B = valid_len.shape[0] + seq = ( + torch.arange(1, max_sequence_length + 1, device=device) + .unsqueeze(0) + .expand(B, -1) + ) + increment_then_repeat = torch.minimum(seq, valid_len.unsqueeze(1)) + pos_ids = torch.zeros((B, max_sequence_length, 3), device=device) + pos_ids[:, :, 0] = increment_then_repeat + return pos_ids.int() + + + +def time_shift(mu: float, 
sigma: float, t: Tensor) -> Tensor: + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list: + """Build a shifted cosine timestep schedule from t=1 (noise) to t=0 (clean).""" + timesteps = torch.linspace(1, 0, num_steps + 1) + if shift: + m = (max_shift - base_shift) / (4096 - 256) + b = base_shift - m * 256 + mu = m * image_seq_len + b + timesteps = time_shift(mu, 1.0, timesteps) + return timesteps.tolist() diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index c82d3ad3..e4001cc8 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -29,7 +29,8 @@ type AdditionalSections = | 'model.low_vram' | 'model.qie.match_target_res' | 'model.assistant_lora_path'; -type ModelGroup = 'image' | 'instruction' | 'video'; + +type ModelGroup = 'image' | 'instruction' | 'video' | 'experimental'; export interface ModelArch { name: string; @@ -133,6 +134,21 @@ export const modelArchs: ModelArch[] = [ }, disableSections: ['network.conv'], }, + { + name: 'zeta_chroma', + label: 'Zeta Chroma', + group: 'experimental', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors', defaultNameOrPath], + 'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, { name: 'wan21:1b', label: 'Wan 2.1 (1.3B)', diff --git a/version.py b/version.py index 467864e1..9d39d73a 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.23" +VERSION = "0.7.24"
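
For reference, a minimal sketch of the pixel-space patchify convention implemented by vae_flatten / vae_unflatten above. It assumes only torch and einops are installed and mirrors the shapes used by ZetaChromaModel (patch_size = 32, RGB input), so each token carries 32 * 32 * 3 = 3072 values, matching in_channels in load_model:

import torch
from einops import rearrange

patch_size = 32
image = torch.randn(1, 3, 512, 512)  # [N, C, H, W], H and W divisible by 32

# flatten: [N, C, H, W] -> [N, num_patches, patch_size*patch_size*C]
tokens = rearrange(
    image, "n c (h dh) (w dw) -> n (h w) (dh dw c)", dh=patch_size, dw=patch_size
)
print(tokens.shape)  # torch.Size([1, 256, 3072]) -> 16x16 patches, 32*32*3 values each

# unflatten: exact inverse back to [N, C, H, W]
restored = rearrange(
    tokens,
    "n (h w) (dh dw c) -> n c (h dh) (w dw)",
    dh=patch_size,
    dw=patch_size,
    c=3,
    h=512 // patch_size,
    w=512 // patch_size,
)
assert torch.equal(restored, image)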