mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Add support for training lodestones/Zeta-Chroma
This commit is contained in:
@@ -8,6 +8,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusMod
|
|||||||
from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
|
from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
|
||||||
from .z_image import ZImageModel
|
from .z_image import ZImageModel
|
||||||
from .ltx2 import LTX2Model
|
from .ltx2 import LTX2Model
|
||||||
|
from .zeta_chroma import ZetaChromaModel
|
||||||
|
|
||||||
AI_TOOLKIT_MODELS = [
|
AI_TOOLKIT_MODELS = [
|
||||||
# put a list of models here
|
# put a list of models here
|
||||||
@@ -29,4 +30,5 @@ AI_TOOLKIT_MODELS = [
|
|||||||
LTX2Model,
|
LTX2Model,
|
||||||
Flux2Klein4BModel,
|
Flux2Klein4BModel,
|
||||||
Flux2Klein9BModel,
|
Flux2Klein9BModel,
|
||||||
|
ZetaChromaModel,
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
from .zeta_chroma_model import ZetaChromaModel
|
||||||
@@ -0,0 +1,380 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig
|
||||||
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
|
from toolkit.models.base_model import BaseModel
|
||||||
|
from toolkit.basic import flush
|
||||||
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
|
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||||
|
CustomFlowMatchEulerDiscreteScheduler,
|
||||||
|
)
|
||||||
|
from toolkit.accelerator import unwrap_model
|
||||||
|
from optimum.quanto import freeze
|
||||||
|
from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||||
|
from toolkit.memory_management import MemoryManager
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from optimum.quanto import QTensor
|
||||||
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from transformers import AutoTokenizer, Qwen3ForCausalLM
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from toolkit.models.FakeVAE import FakeVAE
|
||||||
|
from .zeta_chroma_transformer import ZImageDCT, ZImageDCTParams, vae_flatten, vae_unflatten, prepare_latent_image_ids, make_text_position_ids
|
||||||
|
from .zeta_chroma_pipeline import ZetaChromaPipeline
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Keyword arguments passed to CustomFlowMatchEulerDiscreteScheduler by
# ZetaChromaModel.get_train_scheduler().
scheduler_config = dict(
    num_train_timesteps=1000,
    use_dynamic_shifting=False,
    shift=3.0,
)

# File downloaded from the hub repo when the configured model path does not
# itself end in a ".safetensors" file name.
ZETA_CHROMA_TRANSFORMER_FILENAME = (
    "zeta-chroma-base-x0-pixel-dino-distance.safetensors"
)
|
||||||
|
|
||||||
|
|
||||||
|
class ZetaChromaModel(BaseModel):
    """ai-toolkit model wrapper for training/sampling lodestones/Zeta-Chroma.

    The transformer is a ``ZImageDCT`` that consumes flattened
    ``patch_size`` x ``patch_size`` RGB patches (see ``load_model`` and
    ``get_noise_prediction``). The VAE slot is filled with a ``FakeVAE`` and
    prompts are encoded with a Qwen3 text encoder.
    """

    arch = "zeta_chroma"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        """Forward all arguments to ``BaseModel`` and set model constants."""
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["ZImageDCT"]
        # Side length, in pixels, of one square image patch.
        self.patch_size = 32
        # Prompts are padded/truncated to this many tokens by the pipeline.
        self.max_sequence_length = 512

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        """Return a new flow-match scheduler built from ``scheduler_config``."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        """Return the value image sides must be divisible by (the patch size)."""
        return self.patch_size

    def load_model(self):
        """Load transformer, text encoder, tokenizer, FakeVAE and pipeline.

        ``model_config.name_or_path`` is used as-is when it exists on disk.
        Otherwise it is treated as a Hugging Face repo id, optionally followed
        by ``/<file>.safetensors`` to select a specific file in the repo.
        The tokenizer and text encoder are loaded from the ``tokenizer`` and
        ``text_encoder`` subfolders of ``model_config.extras_name_or_path``.
        """
        dtype = self.torch_dtype
        self.print_and_status_update("Loading ZImage model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path

        self.print_and_status_update("Loading transformer")

        transformer_path = model_path
        if not os.path.exists(transformer_path):
            transformer_name = ZETA_CHROMA_TRANSFORMER_FILENAME
            # if path ends with .safetensors, assume last part is filename
            # this allows users to target different file names in the repo like
            # lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors
            if transformer_path.endswith(".safetensors"):
                splits = transformer_path.split("/")
                transformer_name = splits[-1]
                transformer_path = "/".join(splits[:-1])
            # assume it is from the hub
            transformer_path = huggingface_hub.hf_hub_download(
                repo_id=transformer_path,
                filename=transformer_name,
            )

        transformer_state_dict = load_file(transformer_path, device="cpu")

        # cast to dtype
        for key in transformer_state_dict:
            transformer_state_dict[key] = transformer_state_dict[key].to(dtype)

        # Auto-detect use_x0 from checkpoint
        use_x0 = "__x0__" in transformer_state_dict

        # Build model params
        in_channels = self.patch_size * self.patch_size * 3  # RGB patches
        model_params = ZImageDCTParams(
            patch_size=1,
            in_channels=in_channels,
            use_x0=use_x0,
        )

        # Build the module on the meta device, then let load_state_dict
        # assign the checkpoint tensors directly (assign=True).
        with torch.device("meta"):
            transformer = ZImageDCT(model_params)

        transformer.load_state_dict(transformer_state_dict, assign=True)
        del transformer_state_dict

        transformer.to(self.quantize_device, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=[
                    transformer.x_pad_token,
                    transformer.cap_pad_token,
                ],
            )

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Text Encoder")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype
        )
        text_encoder = Qwen3ForCausalLM.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype
        )

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAE")
        vae = FakeVAE(scaling_factor=1.0)
        vae.to(self.device_torch, dtype=dtype)

        self.noise_scheduler = ZetaChromaModel.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: ZetaChromaPipeline = ZetaChromaPipeline(
            scheduler=self.noise_scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Build a sampling pipeline around the unwrapped, already-loaded modules."""
        scheduler = ZetaChromaModel.get_train_scheduler()

        pipeline: ZetaChromaPipeline = ZetaChromaPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: ZetaChromaPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one image with ``pipeline`` from pre-computed prompt embeds.

        ``gen_config.width`` / ``gen_config.height`` are rounded down in place
        to a multiple of ``get_bucket_divisibility()``. Returns the first
        entry of the pipeline output's ``images``.
        """
        # A single call moves the model and sets its dtype; the second bare
        # ``.to(device)`` that used to follow it was redundant.
        self.model.to(self.device_torch, dtype=self.torch_dtype)

        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_embeds_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        """Run the transformer on a batch of noisy images.

        Args:
            latent_model_input: image batch; dims 0, 2 and 3 are read as
                batch, height and width (presumably (B, 3, H, W) pixels --
                TODO confirm against the caller).
            timestep: timesteps on the 0-1000 scale; divided by 1000 here.
            text_embeddings: prompt embeddings with their attention mask.

        Returns:
            The float32 prediction, un-patchified back to the input shape.
        """
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        # Input preparation needs no gradients. The transformer call below
        # stays outside this context so gradients reach trainable weights.
        with torch.no_grad():
            pixel_shape = latent_model_input.shape
            # todo: do we invert like this?
            # t_vec = (1000 - timestep) / 1000
            t_vec = timestep / 1000

            height = latent_model_input.shape[2]
            h_patches = height // self.patch_size
            width = latent_model_input.shape[3]
            w_patches = width // self.patch_size
            batch_size = latent_model_input.shape[0]

            img, _ = vae_flatten(latent_model_input, patch_size=self.patch_size)

            num_patches = img.shape[1]

            # --- Build position IDs ---
            # Image position ids are offset by each sample's prompt length.
            pos_lengths = text_embeddings.attention_mask.sum(1)
            offset = pos_lengths

            image_pos_ids = prepare_latent_image_ids(
                offset, h_patches, w_patches, patch_size=1
            ).to(self.device_torch)
            pos_text_ids = make_text_position_ids(
                pos_lengths, self.max_sequence_length
            ).to(self.device_torch)
            img_mask = torch.ones(
                (batch_size, num_patches), device=self.device_torch, dtype=torch.bool
            )

        pred = self.transformer(
            img=img,  # e.g. (1, 1024, 3072)
            img_ids=image_pos_ids,  # e.g. (1, 1024, 3)
            img_mask=img_mask,  # e.g. (1, 1024)
            txt=text_embeddings.text_embeds,  # e.g. (1, 512, 2560)
            txt_ids=pos_text_ids,  # e.g. (1, 512, 3)
            txt_mask=text_embeddings.attention_mask,  # e.g. (1, 512)
            timesteps=t_vec,  # e.g. (1,)
        )

        pred = vae_unflatten(pred.float(), pixel_shape, patch_size=self.patch_size)

        return pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` (a string or a list of strings) into PromptEmbeds.

        A bare string is wrapped in a list first: the pipeline's
        ``_encode_prompts`` iterates its argument, so passing a raw ``str``
        would encode every character as a separate prompt.
        """
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompts = [prompt] if isinstance(prompt, str) else prompt

        prompt_embeds, mask = self.pipeline._encode_prompts(
            prompts,
        )
        pe = PromptEmbeds([prompt_embeds, None], attention_mask=mask)

        return pe

    def get_model_has_grad(self):
        """Always report ``False``."""
        return False

    def get_te_has_grad(self):
        """Always report ``False``."""
        return False

    def save_model(self, output_path, meta, save_dtype):
        """Save the transformer weights as one ``.safetensors`` file.

        ``.safetensors`` is appended to ``output_path`` when missing. Quantized
        tensors are dequantized, and every tensor is copied to CPU in
        ``save_dtype`` before saving.
        """
        if not output_path.endswith(".safetensors"):
            output_path = output_path + ".safetensors"
        # only save the unet
        transformer: ZImageDCT = unwrap_model(self.model)
        state_dict = transformer.state_dict()
        save_dict = {}
        for k, v in state_dict.items():
            if isinstance(v, QTensor):
                v = v.dequantize()
            save_dict[k] = v.clone().to("cpu", dtype=save_dtype)

        meta = get_meta_for_safetensors(meta, name="zeta_chroma")
        save_file(save_dict, output_path, metadata=meta)

    def get_loss_target(self, *args, **kwargs):
        """Return the detached target ``noise - batch.latents``.

        Reads the ``noise`` and ``batch`` keyword arguments.
        """
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()

    def get_base_model_version(self):
        """Return the base-model identifier string."""
        return "zeta_chroma"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        """Return the attribute names that hold the transformer blocks."""
        return ["layers"]

    def convert_lora_weights_before_save(self, state_dict):
        """Return a copy of ``state_dict`` with ``transformer.`` renamed to ``diffusion_model.``."""
        return {
            key.replace("transformer.", "diffusion_model."): value
            for key, value in state_dict.items()
        }

    def convert_lora_weights_before_load(self, state_dict):
        """Return a copy of ``state_dict`` with ``diffusion_model.`` renamed to ``transformer.``."""
        return {
            key.replace("diffusion_model.", "transformer."): value
            for key, value in state_dict.items()
        }
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
from diffusers.pipelines.z_image.pipeline_z_image import (
|
||||||
|
ZImagePipeline,
|
||||||
|
calculate_shift,
|
||||||
|
retrieve_timesteps,
|
||||||
|
)
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
import torch
|
||||||
|
from diffusers.utils import logging, replace_example_docstring
|
||||||
|
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
|
||||||
|
from extensions_built_in.diffusion_models.zeta_chroma.zeta_chroma_transformer import (
|
||||||
|
get_schedule,
|
||||||
|
prepare_latent_image_ids,
|
||||||
|
make_text_position_ids,
|
||||||
|
vae_unflatten,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ZetaChromaPipeline(ZImagePipeline):
    """Sampling pipeline for Zeta-Chroma built on diffusers' ``ZImagePipeline``.

    Denoising runs directly on flattened ``patch_size`` x ``patch_size`` RGB
    patches with a plain Euler update and optional classifier-free guidance.
    """

    need_something_here = True
    # Side length, in pixels, of one square image patch.
    patch_size = 32
    # Prompts are padded/truncated to this many tokens.
    max_sequence_length = 512

    @torch.no_grad()
    def _encode_prompts(
        self, prompts: Union[str, List[str]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode prompts with the Qwen3 chat template.

        Args:
            prompts: one prompt or a list of prompts. A bare string is treated
                as a single prompt (iterating it would otherwise encode each
                character separately).

        Returns:
            ``(embeddings, mask)``: the second-to-last hidden state of the
            text encoder and the boolean attention mask from the tokenizer.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        formatted = []
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            formatted.append(text)

        inputs = self.tokenizer(
            formatted,
            padding="max_length",
            max_length=self.max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.text_encoder.device)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = self.text_encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=True,
            )

        # Second-to-last hidden state (same as training)
        embeddings = outputs.hidden_states[-2]
        mask = inputs.attention_mask.bool()
        return embeddings, mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        cfg_normalization: bool = False,
        cfg_truncation: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Generate images from prompts or pre-computed prompt embeddings.

        Args:
            prompt / negative_prompt: text prompts; only encoded when the
                matching ``*_embeds`` argument is not supplied.
            height, width: output size in pixels (required); floor-divided by
                ``patch_size`` to get the patch grid.
            num_inference_steps: passed to ``get_schedule``.
            guidance_scale: classifier-free guidance is applied when > 1.0,
                which requires a negative prompt or negative embeddings.
            generator: RNG used for the initial noise.
            prompt_embeds / prompt_embeds_mask: encoded prompt and its mask.
            negative_prompt_embeds / negative_prompt_embeds_mask: optional
                encoded negative prompt and its mask.
            output_type: ``"latent"`` returns the flattened patches; anything
                else is post-processed by ``image_processor``.
            return_dict: return ``ZImagePipelineOutput`` when True, else a
                one-element tuple.
            max_sequence_length: length used for the text position ids.

        ``sigmas``, ``cfg_normalization``, ``cfg_truncation``,
        ``num_images_per_prompt``, ``latents``, ``joint_attention_kwargs`` and
        the ``callback_*`` arguments are accepted for signature compatibility
        but are not used by this implementation.

        Raises:
            ValueError: if a required input is missing.
        """
        # NOTE(review): ``latents`` is ignored and fresh noise is always drawn;
        # confirm whether callers expect supplied latents to be honoured.
        device = self._execution_device

        if height is None or width is None:
            raise ValueError("`height` and `width` must be provided.")

        if prompt_embeds is None:
            if prompt is None:
                raise ValueError("Provide either `prompt` or `prompt_embeds`.")
            prompt_embeds, prompt_embeds_mask = self._encode_prompts(prompt)
        elif prompt_embeds_mask is None:
            raise ValueError(
                "`prompt_embeds_mask` is required when `prompt_embeds` is given."
            )

        use_cfg = guidance_scale > 1.0
        if negative_prompt_embeds is None:
            if use_cfg:
                if negative_prompt is None:
                    raise ValueError(
                        "`guidance_scale` > 1.0 requires `negative_prompt` or "
                        "`negative_prompt_embeds`."
                    )
                negative_prompt_embeds, negative_prompt_embeds_mask = (
                    self._encode_prompts(negative_prompt)
                )
        elif negative_prompt_embeds_mask is None:
            raise ValueError(
                "`negative_prompt_embeds_mask` is required when "
                "`negative_prompt_embeds` is given."
            )

        batch_size = len(prompt_embeds)
        patch_size = self.patch_size
        in_channels = patch_size * patch_size * 3

        h_patches = height // patch_size
        w_patches = width // patch_size
        num_patches = h_patches * w_patches

        pos_embeds, pos_mask = prompt_embeds, prompt_embeds_mask
        neg_embeds, neg_mask = negative_prompt_embeds, negative_prompt_embeds_mask

        # --- Build position IDs ---
        pos_lengths = pos_mask.sum(1)
        if neg_mask is not None:
            neg_lengths = neg_mask.sum(1)
            # Image ids start after the longer of the two prompts.
            offset = torch.maximum(pos_lengths, neg_lengths)
            neg_text_ids = make_text_position_ids(
                neg_lengths, max_sequence_length
            ).to(device)
        else:
            # No negative prompt: offset by the positive prompt length only.
            offset = pos_lengths
            neg_text_ids = None

        image_pos_ids = prepare_latent_image_ids(
            offset, h_patches, w_patches, patch_size=1
        ).to(device)
        pos_text_ids = make_text_position_ids(pos_lengths, max_sequence_length).to(
            device
        )

        # --- Initial noise ---
        noise = randn_tensor(
            (batch_size, num_patches, in_channels),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        # --- Timestep schedule ---
        timesteps = get_schedule(num_inference_steps, num_patches)

        # --- Denoising loop (CFG) ---
        img = noise
        img_mask = torch.ones(
            (batch_size, num_patches), device=device, dtype=torch.bool
        )
        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
                t_vec = torch.full(
                    (batch_size,), t_curr, dtype=self.dtype, device=device
                )

                pred = self.transformer(
                    img=img,
                    img_ids=image_pos_ids,
                    img_mask=img_mask,
                    txt=pos_embeds,
                    txt_ids=pos_text_ids,
                    txt_mask=pos_mask,
                    timesteps=t_vec,
                )

                if use_cfg:
                    pred_neg = self.transformer(
                        img=img,
                        img_ids=image_pos_ids,
                        img_mask=img_mask,
                        txt=neg_embeds,
                        txt_ids=neg_text_ids,
                        txt_mask=neg_mask,
                        timesteps=t_vec,
                    )
                    pred = pred_neg + guidance_scale * (pred - pred_neg)

                # Euler step from t_curr to t_prev.
                img = img + (t_prev - t_curr) * pred

                progress_bar.update()

        if output_type == "latent":
            image = img
        else:
            # --- Unpatchify: [B, num_patches, C*P*P] -> [B, 3, H, W] ---
            pixel_shape = (batch_size, 3, height, width)
            image = vae_unflatten(img.float(), pixel_shape, patch_size=patch_size)
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ZImagePipelineOutput(images=image)
|
||||||
@@ -0,0 +1,739 @@
|
|||||||
|
# orig code provided by lodestones, altered for ai-toolkit
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Optional
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from einops import rearrange
|
||||||
|
import torch.utils.checkpoint as ckpt
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ZImageDCTParams:
    """Hyper-parameters used to build a ``ZImageDCT`` transformer.

    Field order is part of the positional constructor signature and must not
    change.
    """

    # --- patching / input ---
    patch_size: int = 1
    f_patch_size: int = 1
    in_channels: int = 128

    # --- transformer trunk ---
    dim: int = 3840
    n_layers: int = 30
    n_refiner_layers: int = 2
    n_heads: int = 30
    n_kv_heads: int = 30
    norm_eps: float = 1e-5
    qk_norm: bool = True

    # --- conditioning / positional encoding ---
    cap_feat_dim: int = 2560
    rope_theta: int = 256
    t_scale: float = 1000.0
    # One entry per rotary axis; each instance gets its own fresh list.
    axes_dims: List[int] = field(default_factory=lambda: list((32, 48, 48)))
    axes_lens: List[int] = field(default_factory=lambda: list((1536, 512, 512)))
    adaln_embed_dim: int = 256
    use_x0: bool = True

    # DCT decoder params
    decoder_hidden_size: int = 3840
    decoder_num_res_blocks: int = 4
    decoder_max_freqs: int = 8
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConfig:
    """Minimal config stand-in that only exposes ``patch_size``."""

    def __init__(self) -> None:
        # for diffusers compatibility
        self.patch_size: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
    """Convert an attention mask into the form handed to SDPA.

    ``None`` is passed through. A 2-D mask gains two singleton dims after the
    first one. A boolean mask becomes an additive mask of ``dtype``: 0 where
    the entry is True and ``-inf`` where it is False. Non-boolean masks are
    returned without conversion.
    """
    if attn_mask is None:
        return None

    mask = attn_mask[:, None, None, :] if attn_mask.ndim == 2 else attn_mask
    if mask.dtype != torch.bool:
        return mask

    additive = torch.zeros_like(mask, dtype=dtype)
    return additive.masked_fill_(mask.logical_not(), float("-inf"))
|
||||||
|
|
||||||
|
|
||||||
|
def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run ``F.scaled_dot_product_attention`` on tensors with dims 1 and 2 swapped.

    ``query``/``key``/``value`` have dims 1 and 2 transposed before the call,
    and the result is transposed back and made contiguous. The mask is first
    converted by ``_process_mask`` using the query dtype.
    """
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    attended = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=_process_mask(attn_mask, q.dtype),
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )
    return attended.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep features followed by a Linear-SiLU-Linear MLP."""

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        """Build the MLP; ``mid_size`` defaults to ``out_size``."""
        super().__init__()
        hidden = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden, bias=True),
            nn.SiLU(),
            nn.Linear(hidden, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return float32 ``[cos | sin]`` features of width ``dim`` for 1-D ``t``.

        When ``dim`` is odd, one zero column is appended.
        """
        # Computed in float32 with CUDA autocast disabled.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            steps = torch.arange(
                start=0, end=half, dtype=torch.float32, device=t.device
            )
            freqs = torch.exp(-math.log(max_period) * steps / half)
            angles = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
            if dim % 2 != 0:
                zero_col = torch.zeros_like(embedding[:, :1])
                embedding = torch.cat([embedding, zero_col], dim=-1)
            return embedding

    def forward(self, t):
        """Embed timesteps ``t`` and project them through the MLP."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the first layer's dtype only when it is a floating dtype.
        target_dtype = self.mlp[0].weight.dtype
        if target_dtype.is_floating_point:
            features = features.to(target_dtype)
        return self.mlp(features)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create a per-feature scale of size ``dim`` initialised to ones."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by ``rsqrt(mean(x**2) + eps)`` and then by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return normalised * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``, all bias-free."""

    def __init__(self, dim: int, hidden_dim: int):
        """Create the two input projections (w1, w3) and output projection (w2)."""
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the SiLU-gated projection to ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate ``x_in`` by the complex factors in ``freqs_cis``.

    The last dim of ``x_in`` is read as consecutive (real, imag) pairs,
    multiplied by ``freqs_cis`` (which gains a singleton dim at index 2), and
    flattened back from dim 3 on. The maths runs in float32 with CUDA autocast
    disabled; the result is cast back to the dtype of ``x_in``.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(rotated).flatten(3)
        return x_out.type_as(x_in)
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageAttention(nn.Module):
    """Multi-head self-attention with optional q/k RMSNorm and rotary embedding."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-5,
    ):
        """Create the bias-free q/k/v/out projections and optional q/k norms.

        Args:
            dim: model width; the head size is ``dim // n_heads``.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads.
            qk_norm: when True, apply a per-head RMSNorm to queries and keys.
            eps: epsilon for those RMSNorm layers.
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.to_k = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.to_v = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        # Kept as a one-element ModuleList, so the parameter name stays
        # ``to_out.0.weight``.
        self.to_out = nn.ModuleList(
            [nn.Linear(n_heads * self.head_dim, dim, bias=False)]
        )

        self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None
        self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the output projection.

        Args:
            hidden_states: input whose last dim is the model width.
            attention_mask: optional mask forwarded to the attention wrapper.
            freqs_cis: optional rotary factors applied to queries and keys.
        """
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # Split the projected width into (heads, head_dim).
        query = query.unflatten(-1, (self.n_heads, -1))
        key = key.unflatten(-1, (self.n_kv_heads, -1))
        value = value.unflatten(-1, (self.n_kv_heads, -1))

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Use the value dtype as the reference: value is untouched by norm and
        # RoPE, so casting q/k to it gives SDPA three tensors of one dtype even
        # if a norm weight promoted q/k. (Casting q/k to query's own dtype, as
        # before, was a no-op.)
        dtype = value.dtype
        query, key = query.to(dtype), key.to(dtype)

        hidden_states = _native_attention_wrapper(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge (heads, head_dim) back into one feature dim.
        hidden_states = hidden_states.flatten(2, 3).to(dtype)
        return self.to_out[0](hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTransformerBlock(nn.Module):
    """Attention + feed-forward block with optional adaLN scale/gate modulation."""

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        adaln_embed_dim: int = 256,
    ):
        """Build attention, feed-forward, four RMSNorms and optional adaLN.

        When ``modulation`` is True, a linear layer maps the conditioning
        vector (width ``min(dim, adaln_embed_dim)``) to ``4 * dim`` values.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation

        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        # Norms applied before (…1) and after (…2) each sub-layer.
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            self.adaLN_modulation = nn.ModuleList(
                [nn.Linear(min(dim, adaln_embed_dim), 4 * dim, bias=True)]
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the block to ``x``.

        With modulation, ``adaln_input`` is required. It yields two scales
        (applied as ``1 + scale`` to the pre-normed inputs) and two
        tanh-squashed gates on the residual branches.
        """
        if not self.modulation:
            # Plain residual path without scales or gates.
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
            return x

        assert adaln_input is not None
        modulated = self.adaLN_modulation[0](adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulated.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        attn_out = self.attention(
            self.attention_norm1(x) * scale_msa,
            attention_mask=attn_mask,
            freqs_cis=freqs_cis,
        )
        x = x + gate_msa * self.attention_norm2(attn_out)

        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        x = x + gate_mlp * self.ffn_norm2(ffn_out)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class RopeEmbedder:
    """Multi-axis rotary embedding lookup backed by lazily built tables."""

    def __init__(
        self,
        theta: float = 256,
        axes_dims: List[int] = None,
        axes_lens: List[int] = None,
    ):
        self.theta = theta
        self.axes_dims = axes_dims or [32, 48, 48]
        self.axes_lens = axes_lens or [1536, 512, 512]
        assert len(self.axes_dims) == len(self.axes_lens)
        # Built on first call, then kept on the device of the last ids seen.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256):
        """Return one complex64 table of shape (end_i, dim_i // 2) per axis."""

        def axis_table(axis_dim, axis_len):
            # Frequencies and positions are formed in float64, then the
            # angles are reduced to float32 before building unit phasors.
            exponents = (
                torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu")
                / axis_dim
            )
            inv_freq = 1.0 / (theta**exponents)
            positions = torch.arange(
                axis_len, device=inv_freq.device, dtype=torch.float64
            )
            angles = torch.outer(positions, inv_freq).float()
            return torch.polar(torch.ones_like(angles), angles).to(torch.complex64)

        with torch.device("cpu"):
            return [axis_table(d, e) for d, e in zip(dim, end)]

    def __call__(self, ids: torch.Tensor):
        """Gather each axis table with ``ids[..., axis]`` and concatenate."""
        n_axes = len(self.axes_dims)
        assert ids.ndim >= 2 and ids.shape[-1] == n_axes
        target = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(
                self.axes_dims, self.axes_lens, theta=self.theta
            )
        if self.freqs_cis[0].device != target:
            self.freqs_cis = [table.to(target) for table in self.freqs_cis]

        gathered = [self.freqs_cis[axis][ids[..., axis]] for axis in range(n_axes)]
        return torch.cat(gathered, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Decoder components ---
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
    """Apply an affine modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||||
|
|
||||||
|
|
||||||
|
class NerfEmbedder(nn.Module):
    """Embed per-pixel values concatenated with a DCT-style positional basis.

    Args:
        in_channels: number of channels of each input element.
        hidden_size_input: output width of the linear embedder.
        max_freqs: number of cosine frequencies per spatial axis; the
            positional feature has ``max_freqs ** 2`` entries.
    """

    # Maximum number of positional tables kept per instance.
    _POS_CACHE_SIZE = 4

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input)
        )
        # Per-instance cache. A functools.lru_cache on the method would key a
        # class-level cache on ``self`` and keep every instance (and its
        # cached tensors) alive for the lifetime of the process.
        self._pos_cache = {}

    def fetch_pos(self, patch_size, device, dtype):
        """Return the (1, patch_size**2, max_freqs**2) positional basis.

        Results are memoised per instance, keyed on the arguments and on
        ``max_freqs``; at most ``_POS_CACHE_SIZE`` tables are retained.
        """
        # setdefault keeps instances created without __init__ working.
        cache = self.__dict__.setdefault("_pos_cache", {})
        key = (patch_size, device, dtype, self.max_freqs)
        cached = cache.get(key)
        if cached is not None:
            return cached

        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")

        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        freqs = torch.linspace(
            0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]

        # Higher frequency pairs are damped by 1 / (1 + fx * fy).
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)

        if len(cache) >= self._POS_CACHE_SIZE:
            # Dicts preserve insertion order: drop the oldest entry.
            cache.pop(next(iter(cache)))
        cache[key] = dct
        return dct

    def forward(self, inputs):
        """Embed ``inputs`` of shape (B, P2, C); the result keeps the input dtype.

        The side length of the positional grid is ``int(P2 ** 0.5)``.
        """
        B, P2, C = inputs.shape
        original_dtype = inputs.dtype
        # The embedding is computed in float32 with autocast disabled.
        with torch.autocast("cuda", enabled=False):
            patch_size = int(P2**0.5)
            inputs = inputs.float()
            dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
            dct = dct.repeat(B, 1, 1)
            inputs = torch.cat([inputs, dct], dim=-1)
            # NOTE: .float() converts the embedder parameters in place.
            inputs = self.embedder.float()(inputs)
        return inputs.to(original_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Residual MLP block conditioned through adaLN shift/scale/gate."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init the MLP linears and zero the modulation projection."""
        linear_layers = [layer for layer in self.mlp if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity="linear")
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        # Zero modulation => shift, scale and gate all start at 0, so the
        # block is the identity at initialisation.
        projection = self.adaLN_modulation[-1]
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)

    def forward(self, x, y):
        """Return ``x`` plus the gated, modulated MLP branch conditioned on ``y``."""
        params = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = params.chunk(3, dim=-1)
        branch = self.mlp(modulate(self.in_ln(x), shift_mlp, scale_mlp))
        return x + gate_mlp * branch
|
||||||
|
|
||||||
|
|
||||||
|
class DCTFinalLayer(nn.Module):
    """Non-affine LayerNorm followed by a zero-initialised output projection."""

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6
        )
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        # Zero weight and bias: the layer outputs zeros until trained/loaded.
        for tensor in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x):
        """Normalise ``x`` over its last dimension and project it."""
        normed = self.norm_final(x)
        return self.linear(normed)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMLPAdaLN(nn.Module):
    """Per-patch MLP decoder: embedded inputs refined by conditioned ResBlocks."""

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        patch_size,
        max_freqs=8,
    ):
        super().__init__()
        self.patch_size = patch_size
        # One conditioning vector of width model_channels per patch position.
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_embedder = NerfEmbedder(
            in_channels=in_channels,
            hidden_size_input=model_channels,
            max_freqs=max_freqs,
        )
        blocks = [ResBlock(model_channels) for _ in range(num_res_blocks)]
        self.res_blocks = nn.ModuleList(blocks)
        self.final_layer = DCTFinalLayer(model_channels, out_channels)

        nn.init.xavier_uniform_(self.cond_embed.weight)
        nn.init.constant_(self.cond_embed.bias, 0)

    def forward(self, x, c):
        """Decode ``x`` under condition ``c`` (one condition row per batch item)."""
        hidden = self.input_embedder(x)
        cond = self.cond_embed(c)
        # Split the flat conditioning into one vector per patch position.
        per_position = cond.reshape(cond.shape[0], self.patch_size**2, -1)
        for res_block in self.res_blocks:
            hidden = res_block(hidden, per_position)
        return self.final_layer(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageDCT(nn.Module):
    """Transformer backbone with a per-patch MLP decoder head.

    Image patches and text features are embedded, refined separately
    (``noise_refiner`` / ``context_refiner``), concatenated as
    ``[text, image]`` and passed through the main ``layers``. The image part
    of the result conditions ``dec_net``, which decodes every patch from its
    own input values.
    """

    def __init__(self, params: ZImageDCTParams):
        """Build all sub-modules from ``params``."""
        super().__init__()
        self.config = FakeConfig()
        self.in_channels = params.in_channels
        # Output channel count is tied to the input channel count.
        self.out_channels = params.in_channels
        self.patch_size = params.patch_size
        self.f_patch_size = params.f_patch_size
        self.dim = params.dim
        self.n_heads = params.n_heads
        self.rope_theta = params.rope_theta
        self.t_scale = params.t_scale
        self.adaln_embed_dim = params.adaln_embed_dim

        # Projects one flattened patch to the model width.
        self.x_embedder = nn.Linear(
            self.f_patch_size * self.patch_size * self.patch_size * params.in_channels,
            params.dim,
            bias=True,
        )

        # Image-only refinement blocks, modulated by the timestep embedding.
        # Their layer ids are offset by 1000.
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=True,
                    adaln_embed_dim=params.adaln_embed_dim,
                )
                for i in range(params.n_refiner_layers)
            ]
        )

        # Text-only refinement blocks; no timestep modulation.
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=False,
                )
                for i in range(params.n_refiner_layers)
            ]
        )

        # Embedding width matches the input width of each block's adaLN linear.
        self.t_embedder = TimestepEmbedder(
            min(params.dim, params.adaln_embed_dim), mid_size=1024
        )

        # Normalises caption features and projects them to the model width.
        self.cap_embedder = nn.Sequential(
            RMSNorm(params.cap_feat_dim, eps=params.norm_eps),
            nn.Linear(params.cap_feat_dim, params.dim, bias=True),
        )

        # Allocated uninitialised (torch.empty) and not read by the forward
        # pass of this class.
        # NOTE(review): presumably kept so checkpoint keys match - confirm.
        self.x_pad_token = nn.Parameter(torch.empty((1, params.dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, params.dim)))

        # Main blocks applied to the concatenated [text, image] sequence.
        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    i,
                    params.dim,
                    params.n_heads,
                    params.n_kv_heads,
                    params.norm_eps,
                    params.qk_norm,
                    modulation=True,
                    adaln_embed_dim=params.adaln_embed_dim,
                )
                for i in range(params.n_layers)
            ]
        )

        # The rotary axes must exactly partition the per-head dimension.
        head_dim = params.dim // params.n_heads
        assert head_dim == sum(params.axes_dims)
        self.axes_dims = params.axes_dims
        self.axes_lens = params.axes_lens

        self.rope_embedder = RopeEmbedder(
            theta=params.rope_theta,
            axes_dims=params.axes_dims,
            axes_lens=params.axes_lens,
        )

        # Per-patch decoder conditioned on the transformer output (z_channels=dim).
        self.dec_net = SimpleMLPAdaLN(
            in_channels=params.in_channels,
            model_channels=params.decoder_hidden_size,
            out_channels=params.in_channels,
            z_channels=params.dim,
            num_res_blocks=params.decoder_num_res_blocks,
            patch_size=self.patch_size,
            max_freqs=params.decoder_max_freqs,
        )

        if params.use_x0:
            # Empty marker buffer: forward() checks for its presence with
            # hasattr to decide whether to post-process the output.
            self.register_buffer("__x0__", torch.tensor([]))

        self.gradient_checkpointing = False

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        """Turn on activation checkpointing for the main ``layers``."""
        self.gradient_checkpointing = True

    def _forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        img_mask: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
    ):
        """Run the backbone and decoder; returns the negated decoder output.

        The result has shape ``(B, num_patches, -1)`` where ``B`` and
        ``num_patches`` are the first two dimensions of ``img``.
        """
        B = img.shape[0]
        num_patches = img.shape[1]

        # Every patch becomes its own decoder batch item of
        # (patch_size**2, in_channels) values.
        pixel_values = img.reshape(
            B * num_patches, self.patch_size**2, self.in_channels
        )

        # Time is flipped (1 - t) and scaled before being embedded.
        timesteps = (1 - timesteps) * self.t_scale
        timesteps_embedding = self.t_embedder(timesteps)

        img_hidden = self.x_embedder(img)
        txt_hidden = self.cap_embedder(txt)

        img_pe = self.rope_embedder(img_ids)
        txt_pe = self.rope_embedder(txt_ids)

        # Modality-specific refinement; only the image stream sees the timestep.
        for layer in self.noise_refiner:
            img_hidden = layer(img_hidden, img_mask, img_pe, timesteps_embedding)

        for layer in self.context_refiner:
            txt_hidden = layer(txt_hidden, txt_mask, txt_pe)

        # Joint sequence is ordered [text, image] along dim 1.
        mixed_hidden = torch.cat((txt_hidden, img_hidden), 1)
        mixed_mask = torch.cat((txt_mask, img_mask), 1)
        mixed_pe = torch.cat((txt_pe, img_pe), 1)

        for layer in self.layers:
            # Checkpointing is only applied to the main layers, and only
            # while gradients are being recorded.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                mixed_hidden = ckpt.checkpoint(
                    layer,
                    mixed_hidden,
                    mixed_mask,
                    mixed_pe,
                    timesteps_embedding,
                    use_reentrant=False,
                )
            else:
                mixed_hidden = layer(
                    mixed_hidden, mixed_mask, mixed_pe, timesteps_embedding
                )

        # Drop the leading text tokens, keeping only the image positions.
        img_hidden = mixed_hidden[:, txt.shape[1] :, ...]

        # One conditioning vector per patch, aligned with pixel_values.
        decoder_condition = img_hidden.reshape(B * num_patches, self.dim)
        output = self.dec_net(pixel_values, decoder_condition)
        output = output.reshape(B, num_patches, -1)

        return -output

    def _apply_x0_residual(self, predicted, noisy, timesteps):
        """Return ``(noisy - predicted) / timesteps``, one timestep per batch item.

        NOTE(review): presumably converts an x0-style prediction into a
        velocity-style one - confirm against the training objective. A zero
        timestep divides by zero.
        """
        return (noisy - predicted) / timesteps.view(-1, 1, 1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        img_mask: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
    ):
        """Run ``_forward`` and, when the ``__x0__`` marker buffer is
        registered, post-process with ``_apply_x0_residual``."""
        out = self._forward(
            img=img,
            img_ids=img_ids,
            img_mask=img_mask,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=timesteps,
        )
        if hasattr(self, "__x0__"):
            return self._apply_x0_residual(out, img, timesteps)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
def vae_flatten(latents, patch_size=2):
    """Patchify ``[N, C, H, W]`` into ``[N, num_patches, patch_size*patch_size*C]``.

    Returns a ``(patches, original_shape)`` tuple so the operation can be
    undone later.
    """
    original_shape = latents.shape
    patches = rearrange(
        latents,
        "n c (h dh) (w dw) -> n (h w) (dh dw c)",
        dh=patch_size,
        dw=patch_size,
    )
    return patches, original_shape
|
||||||
|
|
||||||
|
|
||||||
|
def vae_unflatten(latents, shape, patch_size=2):
    """Unpatchify ``[N, num_patches, patch_size*patch_size*C]`` back to ``[N, C, H, W]``.

    ``shape`` is the original ``(N, C, H, W)`` shape; only C, H and W are used.
    """
    _, channels, height, width = shape
    axis_sizes = {
        "dh": patch_size,
        "dw": patch_size,
        "c": channels,
        "h": height // patch_size,
        "w": width // patch_size,
    }
    return rearrange(
        latents,
        "n (h w) (dh dw c) -> n c (h dh) (w dw)",
        **axis_sizes,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_latent_image_ids(start_indices, height, width, patch_size=2, max_offset=0):
    """Build integer (index, row, col) position ids for every image patch.

    Returns an int tensor of shape ``(batch, rows * cols, 3)`` where channel 0
    holds the per-sample start index, channel 1 the patch row and channel 2
    the patch column. ``max_offset`` is accepted but not used.
    """
    if isinstance(start_indices, list):
        start_indices = torch.tensor(start_indices)

    n_samples = len(start_indices)
    rows = height // patch_size
    cols = width // patch_size

    grid = torch.zeros(rows, cols, 3)
    # Channel 1 counts rows, channel 2 counts columns (broadcast on assign).
    grid[..., 1] = torch.arange(rows)[:, None]
    grid[..., 2] = torch.arange(cols)[None, :]

    ids = grid[None, :].repeat(n_samples, 1, 1, 1)
    for sample, start in enumerate(start_indices):
        ids[sample, :, :, 0] = start

    return ids.reshape(n_samples, rows * cols, 3).int()
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_position_ids(valid_len, max_sequence_length, extra_padding=0):
    """Build integer position ids for text tokens.

    Channel 0 counts 1, 2, ... up to ``valid_len + extra_padding`` and then
    repeats that last value; channels 1 and 2 stay zero. The result has shape
    ``(batch, max_sequence_length, 3)``.
    """
    device = valid_len.device
    limit = (valid_len + extra_padding).unsqueeze(1)
    batch = limit.shape[0]

    counter = torch.arange(1, max_sequence_length + 1, device=device)[None, :]
    # Broadcasting (1, L) against (B, 1) clamps each row at its own limit.
    clamped = torch.minimum(counter, limit)

    pos_ids = torch.zeros((batch, max_sequence_length, 3), device=device)
    pos_ids[:, :, 0] = clamped
    return pos_ids.int()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    """Warp ``t`` as ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``."""
    exp_mu = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + odds)
|
||||||
|
|
||||||
|
|
||||||
|
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list:
    """Return ``num_steps + 1`` timesteps running from 1 (noise) to 0 (clean).

    The values are evenly spaced; when ``shift`` is true they are warped by
    ``time_shift`` with a ``mu`` interpolated linearly in ``image_seq_len``
    (``base_shift`` at 256 tokens, ``max_shift`` at 4096 tokens).
    """
    uniform = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform.tolist()

    # Line through (256, base_shift) and (4096, max_shift).
    m = (max_shift - base_shift) / (4096 - 256)
    b = base_shift - m * 256
    mu = m * image_seq_len + b
    return time_shift(mu, 1.0, uniform).tolist()
|
||||||
@@ -29,7 +29,8 @@ type AdditionalSections =
|
|||||||
| 'model.low_vram'
|
| 'model.low_vram'
|
||||||
| 'model.qie.match_target_res'
|
| 'model.qie.match_target_res'
|
||||||
| 'model.assistant_lora_path';
|
| 'model.assistant_lora_path';
|
||||||
type ModelGroup = 'image' | 'instruction' | 'video';
|
|
||||||
|
type ModelGroup = 'image' | 'instruction' | 'video' | 'experimental';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
name: string;
|
name: string;
|
||||||
@@ -133,6 +134,21 @@ export const modelArchs: ModelArch[] = [
|
|||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: 'zeta_chroma',
|
||||||
|
label: 'Zeta Chroma',
|
||||||
|
group: 'experimental',
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors', defaultNameOrPath],
|
||||||
|
'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined],
|
||||||
|
'config.process[0].model.quantize': [true, false],
|
||||||
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21:1b',
|
name: 'wan21:1b',
|
||||||
label: 'Wan 2.1 (1.3B)',
|
label: 'Wan 2.1 (1.3B)',
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.7.23"
|
VERSION = "0.7.24"
|
||||||
|
|||||||
Reference in New Issue
Block a user