diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index d871833a..31fb1ac5 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -6,6 +6,7 @@ from .flux_kontext import FluxKontextModel
 from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
 from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
 from .flux2 import Flux2Model
+from .z_image import ZImageModel
 
 AI_TOOLKIT_MODELS = [
     # put a list of models here
@@ -23,4 +24,5 @@ AI_TOOLKIT_MODELS = [
     QwenImageEditModel,
     QwenImageEditPlusModel,
     Flux2Model,
+    ZImageModel,
 ]
diff --git a/extensions_built_in/diffusion_models/z_image/__init__.py b/extensions_built_in/diffusion_models/z_image/__init__.py
new file mode 100644
index 00000000..f95953da
--- /dev/null
+++ b/extensions_built_in/diffusion_models/z_image/__init__.py
@@ -0,0 +1 @@
+from .z_image import ZImageModel
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py
new file mode 100644
index 00000000..29542635
--- /dev/null
+++ b/extensions_built_in/diffusion_models/z_image/z_image.py
@@ -0,0 +1,396 @@
+import os
+from typing import List, Optional
+
+import huggingface_hub
+import torch
+import yaml
+from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig
+from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.models.base_model import BaseModel
+from toolkit.basic import flush
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.samplers.custom_flowmatch_sampler import (
+    CustomFlowMatchEulerDiscreteScheduler,
+)
+from toolkit.accelerator import unwrap_model
+from optimum.quanto import freeze
+from toolkit.util.quantize import quantize, get_qtype, quantize_model
+from toolkit.memory_management import MemoryManager
+from safetensors.torch import load_file
+
+from transformers import AutoTokenizer, Qwen3ForCausalLM
+from diffusers import AutoencoderKL
+
+try:
+    from diffusers import ZImagePipeline
+    from diffusers.models.transformers import ZImageTransformer2DModel
+except ImportError:
+    raise ImportError(
+        "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
+    )
+
+
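+# flow matching scheduler settings shared by training and sampling (fixed shift of 3.0, no dynamic shifting)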
+scheduler_config = {
+    "num_train_timesteps": 1000,
+    "use_dynamic_shifting": False,
+    "shift": 3.0,
+}
+
+
+class ZImageModel(BaseModel):
+    arch = "zimage"
+
+    def __init__(
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype="bf16",
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs,
+    ):
+        super().__init__(
+            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
+        )
+        self.is_flow_matching = True
+        self.is_transformer = True
+        self.target_lora_modules = ["ZImageTransformer2DModel"]
+
+    # static method to get the noise scheduler
+    @staticmethod
+    def get_train_scheduler():
+        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
+
+    def get_bucket_divisibility(self):
+        return 16 * 2  # 16 for the VAE, 2 for patch size
+
+    def load_training_adapter(self, transformer: ZImageTransformer2DModel):
+        self.print_and_status_update("Loading assistant LoRA")
+        lora_path = self.model_config.assistant_lora_path
+        if not os.path.exists(lora_path):
+            # assume it is a hub path
+            lora_splits = lora_path.split("/")
+            if len(lora_splits) != 3:
+                raise ValueError(
+                    f"Assistant LoRA path {lora_path} is not a valid local path or hub path."
+                )
+            repo_id = "/".join(lora_splits[:2])
+            filename = lora_splits[2]
+            try:
+                lora_path = huggingface_hub.hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                )
+                # upgrade the path to the downloaded local file
+                self.model_config.assistant_lora_path = lora_path
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to download assistant LoRA from {lora_path}: {e}"
+                )
+        # load the adapter and merge it in. We will run inference with a -1.0 multiplier so the adapter effects only apply during training.
+        lora_state_dict = load_file(lora_path)
+        dim = int(
+            lora_state_dict[
+                "diffusion_model.layers.0.attention.to_k.lora_A.weight"
+            ].shape[0]
+        )
+
+        new_sd = {}
+        for key, value in lora_state_dict.items():
+            new_key = key.replace("diffusion_model.", "transformer.")
+            new_sd[new_key] = value
+        lora_state_dict = new_sd
+
+        network_config = {
+            "type": "lora",
+            "linear": dim,
+            "linear_alpha": dim,
+            "transformer_only": True,
+        }
+
+        network_config = NetworkConfig(**network_config)
+        LoRASpecialNetwork.LORA_PREFIX_UNET = "lora_transformer"
+        network = LoRASpecialNetwork(
+            text_encoder=None,
+            unet=transformer,
+            lora_dim=network_config.linear,
+            multiplier=1.0,
+            alpha=network_config.linear_alpha,
+            train_unet=True,
+            train_text_encoder=False,
+            network_config=network_config,
+            network_type=network_config.type,
+            transformer_only=network_config.transformer_only,
+            is_transformer=True,
+            target_lin_modules=self.target_lora_modules,
+            is_assistant_adapter=True,
+        )
+        network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
+        self.print_and_status_update("Merging in assistant LoRA")
+        network.force_to(self.device_torch, dtype=self.torch_dtype)
+        network._update_torch_multiplier()
+        network.load_weights(lora_state_dict)
+
+        network.merge_in(merge_weight=1.0)
+
+        # mark it as not merged so inference ignores it.
+        network.is_merged_in = False
+
+        # add the assistant so the sampler will activate it while sampling
+        self.assistant_lora: LoRASpecialNetwork = network
+
+        # deactivate lora during training
+        self.assistant_lora.multiplier = -1.0
+        self.assistant_lora.is_active = False
+
+        # tell the model to invert the assistant on inference since we want to remove the lora effects
+        self.invert_assistant_lora = True
+
+    def load_model(self):
+        dtype = self.torch_dtype
+        self.print_and_status_update("Loading ZImage model")
+        model_path = self.model_config.name_or_path
+        base_model_path = self.model_config.extras_name_or_path
+
+        self.print_and_status_update("Loading transformer")
+
+        transformer_path = model_path
+        transformer_subfolder = "transformer"
+        if os.path.exists(transformer_path):
+            transformer_subfolder = None
+            transformer_path = os.path.join(transformer_path, "transformer")
+            # check if the path is a full checkpoint.
+            te_folder_path = os.path.join(model_path, "text_encoder")
+            # if we have the te, this folder is a full checkpoint, use it as the base
+            if os.path.exists(te_folder_path):
+                base_model_path = model_path
+
+        transformer = ZImageTransformer2DModel.from_pretrained(
+            transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
+        )
+
+        # load assistant lora if specified
+        if self.model_config.assistant_lora_path is not None:
+            self.load_training_adapter(transformer)
+            # set qtype to be float8 if it is qfloat8
+            if self.model_config.qtype == "qfloat8":
+                self.model_config.qtype = "float8"
+
+        if self.model_config.quantize:
+            self.print_and_status_update("Quantizing Transformer")
+            quantize_model(self, transformer)
+            flush()
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_transformer_percent > 0
+        ):
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent,
+            )
+
+        if self.model_config.low_vram:
+            self.print_and_status_update("Moving transformer to CPU")
+            transformer.to("cpu")
+
+        flush()
+
+        self.print_and_status_update("Text Encoder")
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_path, subfolder="tokenizer", torch_dtype=dtype
+        )
+        text_encoder = Qwen3ForCausalLM.from_pretrained(
+            base_model_path, subfolder="text_encoder", torch_dtype=dtype
+        )
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent > 0
+        ):
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
+            )
+
+        text_encoder.to(self.device_torch, dtype=dtype)
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Text Encoder")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
+            freeze(text_encoder)
+            flush()
+
+        self.print_and_status_update("Loading VAE")
+        vae = AutoencoderKL.from_pretrained(
+            base_model_path, subfolder="vae", torch_dtype=dtype
+        )
+
+        self.noise_scheduler = ZImageModel.get_train_scheduler()
+
+        self.print_and_status_update("Making pipe")
+
+        kwargs = {}
+
+        pipe: ZImagePipeline = ZImagePipeline(
+            scheduler=self.noise_scheduler,
+            text_encoder=None,
+            tokenizer=tokenizer,
+            vae=vae,
+            transformer=None,
+            **kwargs,
+        )
+        # for quantization, it works best to do these after making the pipe
+        pipe.text_encoder = text_encoder
+        pipe.transformer = transformer
+
+        self.print_and_status_update("Preparing Model")
+
+        text_encoder = [pipe.text_encoder]
+        tokenizer = [pipe.tokenizer]
+
+        # leave it on cpu for now
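+        # (only low_vram mode keeps the transformer offloaded; otherwise move it to the training device now)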
+        if not self.low_vram:
+            pipe.transformer = pipe.transformer.to(self.device_torch)
+
+        flush()
+        # just to make sure everything is on the right device and dtype
+        text_encoder[0].to(self.device_torch)
+        text_encoder[0].requires_grad_(False)
+        text_encoder[0].eval()
+        flush()
+
+        # save it to the model class
+        self.vae = vae
+        self.text_encoder = text_encoder  # list of text encoders
+        self.tokenizer = tokenizer  # list of tokenizers
+        self.model = pipe.transformer
+        self.pipeline = pipe
+        self.print_and_status_update("Model Loaded")
+
+    def get_generation_pipeline(self):
+        scheduler = ZImageModel.get_train_scheduler()
+
+        pipeline: ZImagePipeline = ZImagePipeline(
+            scheduler=scheduler,
+            text_encoder=unwrap_model(self.text_encoder[0]),
+            tokenizer=self.tokenizer[0],
+            vae=unwrap_model(self.vae),
+            transformer=unwrap_model(self.transformer),
+        )
+
+        pipeline = pipeline.to(self.device_torch)
+
+        return pipeline
+
+    def generate_single_image(
+        self,
+        pipeline: ZImagePipeline,
+        gen_config: GenerateImageConfig,
+        conditional_embeds: PromptEmbeds,
+        unconditional_embeds: PromptEmbeds,
+        generator: torch.Generator,
+        extra: dict,
+    ):
+        self.model.to(self.device_torch, dtype=self.torch_dtype)
+        self.model.to(self.device_torch)
+
+        sc = self.get_bucket_divisibility()
+        gen_config.width = int(gen_config.width // sc * sc)
+        gen_config.height = int(gen_config.height // sc * sc)
+        img = pipeline(
+            prompt_embeds=conditional_embeds.text_embeds,
+            negative_prompt_embeds=unconditional_embeds.text_embeds,
+            height=gen_config.height,
+            width=gen_config.width,
+            num_inference_steps=gen_config.num_inference_steps,
+            guidance_scale=gen_config.guidance_scale,
+            latents=gen_config.latents,
+            generator=generator,
+            **extra,
+        ).images[0]
+        return img
+
+    def get_noise_prediction(
+        self,
+        latent_model_input: torch.Tensor,
+        timestep: torch.Tensor,  # 0 to 1000 scale
+        text_embeddings: PromptEmbeds,
+        **kwargs,
+    ):
+        self.model.to(self.device_torch)
+
+        latent_model_input = latent_model_input.unsqueeze(2)
+        latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+        timestep_model_input = (1000 - timestep) / 1000
+
+        model_out_list = self.transformer(
+            latent_model_input_list,
+            timestep_model_input,
+            text_embeddings.text_embeds,
+        )[0]
+
+        noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
+
+        noise_pred = noise_pred.squeeze(2)
+        noise_pred = -noise_pred
+
+        return noise_pred
+
+    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+        if self.pipeline.text_encoder.device != self.device_torch:
+            self.pipeline.text_encoder.to(self.device_torch)
+
+        prompt_embeds, _ = self.pipeline.encode_prompt(
+            prompt,
+            do_classifier_free_guidance=False,
+            device=self.device_torch,
+        )
+        pe = PromptEmbeds([prompt_embeds, None])
+        return pe
+
+    def get_model_has_grad(self):
+        return False
+
+    def get_te_has_grad(self):
+        return False
+
+    def save_model(self, output_path, meta, save_dtype):
+        transformer: ZImageTransformer2DModel = unwrap_model(self.model)
+        transformer.save_pretrained(
+            save_directory=os.path.join(output_path, "transformer"),
+            safe_serialization=True,
+        )
+
+        meta_path = os.path.join(output_path, "aitk_meta.yaml")
+        with open(meta_path, "w") as f:
+            yaml.dump(meta, f)
+
+    def get_loss_target(self, *args, **kwargs):
+        noise = kwargs.get("noise")
+        batch = kwargs.get("batch")
+        return (noise - batch.latents).detach()
+
+    def get_base_model_version(self):
+        return "zimage"
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        return ["layers"]
+
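+    # LoRA weights are saved with "diffusion_model." key prefixes and converted back to the internal "transformer." prefixes when loaded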
+    def convert_lora_weights_before_save(self, state_dict):
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("transformer.", "diffusion_model.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def convert_lora_weights_before_load(self, state_dict):
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("diffusion_model.", "transformer.")
+            new_sd[new_key] = value
+        return new_sd
diff --git a/requirements.txt b/requirements.txt
index e5442be3..dd29dec4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 torchao==0.10.0
 safetensors
 git+https://github.com/jaretburkett/easy_dwpose.git
-git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422
-transformers==4.52.4
+git+https://github.com/huggingface/diffusers@6bf668c4d217ebc96065e673d8a257fd79950d34
+transformers==4.57.3
 lycoris-lora==1.8.3
 flatten_json
 pyyaml
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 31f18c1b..dd290d5f 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -805,7 +805,11 @@ class BaseModel:
 
         # check if batch size of embeddings matches batch size of latents
         if isinstance(text_embeddings.text_embeds, list):
-            te_batch_size = text_embeddings.text_embeds[0].shape[0]
+            if len(text_embeddings.text_embeds) == latents.shape[0]:
+                # handle list of embeddings
+                te_batch_size = len(text_embeddings.text_embeds)
+            else:
+                te_batch_size = text_embeddings.text_embeds[0].shape[0]
         else:
             te_batch_size = text_embeddings.text_embeds.shape[0]
         if latents.shape[0] == te_batch_size:
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 5421beb8..0cea3013 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -376,6 +376,11 @@ class ToolkitModuleMixin:
         if hasattr(self, 'scalar'):
             scale = scale * self.scalar
 
+        weight_device = weight.device
+        if weight.device != down_weight.device:
+            weight = weight.to(down_weight.device)
+        if scale.device != down_weight.device:
+            scale = scale.to(down_weight.device)
         # merge weight
         if self.full_rank:
             weight = weight + multiplier * down_weight * scale
@@ -397,7 +402,7 @@ class ToolkitModuleMixin:
             weight = weight + multiplier * conved * scale
 
         # set weight to org_module
-        org_sd[weight_key] = weight.to(orig_dtype)
+        org_sd[weight_key] = weight.to(weight_device, orig_dtype)
         self.org_module[0].load_state_dict(org_sd)
 
     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index b8a6f1e5..aaf0c94a 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -72,7 +72,10 @@ class PromptEmbeds:
         if self.pooled_embeds is not None:
             prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
         else:
-            prompt_embeds = PromptEmbeds(cloned_text_embeds)
+            if isinstance(cloned_text_embeds, list) or isinstance(cloned_text_embeds, tuple):
+                prompt_embeds = PromptEmbeds([cloned_text_embeds, None])
+            else:
+                prompt_embeds = PromptEmbeds(cloned_text_embeds)
 
         if self.attention_mask is not None:
             if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 093481ed..876b8ffd 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -205,6 +205,20 @@ export default function SimpleJob({
               placeholder=""
               required
             />
+          {modelArch?.additionalSections?.includes('model.assistant_lora_path') && (
+            <TextInput
+              label="Assistant LoRA Path"
+              value={jobConfig.config.process[0].model.assistant_lora_path || ''}
+              onChange={value => {
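+                // treat a blank path as unset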
+                if (value?.trim() === '') {
+                  value = undefined;
+                }
+                setJobConfig(value, 'config.process[0].model.assistant_lora_path');
+              }}
+              placeholder=""
+            />
+          )}
           {modelArch?.additionalSections?.includes('model.low_vram') && (
             { // Sort by label, case-insensitive
               return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
diff --git a/ui/src/types.ts b/ui/src/types.ts
index 80f0b782..7bc57dc7 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -160,6 +160,7 @@ export interface ModelConfig {
   layer_offloading?: boolean;
   layer_offloading_transformer_percent?: number;
   layer_offloading_text_encoder_percent?: number;
+  assistant_lora_path?: string;
 }
 
 export interface SampleItem {
diff --git a/version.py b/version.py
index c6746c8c..5a921ca6 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.7.5"
\ No newline at end of file
+VERSION = "0.7.6"
\ No newline at end of file