diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 909c3ccb..6d20efa5 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -5,7 +5,7 @@ from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel -from .flux2 import Flux2Model +from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel from .z_image import ZImageModel from .ltx2 import LTX2Model @@ -27,4 +27,6 @@ AI_TOOLKIT_MODELS = [ Flux2Model, ZImageModel, LTX2Model, + Flux2Klein4BModel, + Flux2Klein9BModel, ] diff --git a/extensions_built_in/diffusion_models/flux2/__init__.py b/extensions_built_in/diffusion_models/flux2/__init__.py index 0f6a62ed..24ab4d91 100644 --- a/extensions_built_in/diffusion_models/flux2/__init__.py +++ b/extensions_built_in/diffusion_models/flux2/__init__.py @@ -1 +1,2 @@ -from .flux2_model import Flux2Model \ No newline at end of file +from .flux2_model import Flux2Model +from .flux2_klein_model import Flux2Klein4BModel, Flux2Klein9BModel diff --git a/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py new file mode 100644 index 00000000..86fdafa5 --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py @@ -0,0 +1,92 @@ +from .flux2_model import Flux2Model +from transformers import Qwen3ForCausalLM, Qwen2Tokenizer +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype +from toolkit.config_modules import ModelConfig +from toolkit.memory_management.manager import MemoryManager +from toolkit.basic import flush +from .src.model import Klein9BParams, Klein4BParams + + +class Flux2KleinModel(Flux2Model): + flux2_klein_te_path: str = None + flux2_te_type: str = "qwen" # "mistral" or "qwen" + flux2_vae_path: str = "ai-toolkit/flux2_vae" + flux2_is_guidance_distilled: bool = False + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs, + ) + # use the new format on this new model by default + self.use_old_lokr_format = False + + def load_te(self): + if self.flux2_klein_te_path is None: + raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel") + dtype = self.torch_dtype + self.print_and_status_update("Loading Qwen3") + + text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained( + self.flux2_klein_te_path, + torch_dtype=dtype, + ) + text_encoder.to(self.device_torch, dtype=dtype) + + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Qwen3") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path) + return text_encoder, tokenizer + + +class Flux2Klein4BModel(Flux2KleinModel): + arch = "flux2_klein_4b" + flux2_klein_te_path: str = "Qwen/Qwen3-4B" + flux2_te_filename: str = "flux-2-klein-base-4b.safetensors" + + def get_flux2_params(self): + return Klein4BParams() + + def get_base_model_version(self): + return "flux2_klein_4b" + + +class Flux2Klein9BModel(Flux2KleinModel): + arch = "flux2_klein_9b" + flux2_klein_te_path: str = "Qwen/Qwen3-8B" + flux2_te_filename: str = "flux-2-klein-base-9b.safetensors" + + def get_flux2_params(self): + return Klein9BParams() + + def get_base_model_version(self): + return "flux2_klein_9b" diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 3ed3f482..726d0e40 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None) class Flux2Model(BaseModel): arch = "flux2" + flux2_te_type: str = "mistral" # "mistral" or "qwen" + flux2_vae_path: str = None + flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME + flux2_is_guidance_distilled: bool = True def __init__( self, @@ -84,6 +88,42 @@ class Flux2Model(BaseModel): def get_bucket_divisibility(self): return 16 + def get_flux2_params(self): + return Flux2Params() + + def load_te(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Mistral") + + text_encoder: Mistral3ForConditionalGeneration = ( + Mistral3ForConditionalGeneration.from_pretrained( + MISTRAL_PATH, + torch_dtype=dtype, + ) + ) + text_encoder.to(self.device_torch, dtype=dtype) + + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Mistral") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH) + return text_encoder, tokenizer + def load_model(self): dtype = self.torch_dtype self.print_and_status_update("Loading Flux2 model") @@ -93,19 +133,17 @@ class Flux2Model(BaseModel): self.print_and_status_update("Loading transformer") with torch.device("meta"): - transformer = Flux2(Flux2Params()) + transformer = Flux2(self.get_flux2_params()) # use local path if provided - if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)): - transformer_path = os.path.join( - transformer_path, FLUX2_TRANSFORMER_FILENAME - ) + if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)): + transformer_path = os.path.join(transformer_path, self.flux2_te_filename) if not os.path.exists(transformer_path): # assume it is from the hub transformer_path = huggingface_hub.hf_hub_download( repo_id=model_path, - filename=FLUX2_TRANSFORMER_FILENAME, + filename=self.flux2_te_filename, token=HF_TOKEN, ) @@ -143,35 +181,7 @@ class Flux2Model(BaseModel): self.print_and_status_update("Moving transformer to CPU") transformer.to("cpu") - self.print_and_status_update("Loading Mistral") - - text_encoder: Mistral3ForConditionalGeneration = ( - Mistral3ForConditionalGeneration.from_pretrained( - MISTRAL_PATH, - torch_dtype=dtype, - ) - ) - text_encoder.to(self.device_torch, dtype=dtype) - - flush() - - if self.model_config.quantize_te: - self.print_and_status_update("Quantizing Mistral") - quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) - freeze(text_encoder) - flush() - - if ( - self.model_config.layer_offloading - and self.model_config.layer_offloading_text_encoder_percent > 0 - ): - MemoryManager.attach( - text_encoder, - self.device_torch, - offload_percent=self.model_config.layer_offloading_text_encoder_percent, - ) - - tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH) + text_encoder, tokenizer = self.load_te() self.print_and_status_update("Loading VAE") vae_path = self.model_config.vae_path @@ -179,10 +189,14 @@ class Flux2Model(BaseModel): if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)): vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME) + if vae_path is None: + vae_path = self.flux2_vae_path + if vae_path is None or not os.path.exists(vae_path): + p = vae_path if vae_path is not None else model_path # assume it is from the hub vae_path = huggingface_hub.hf_hub_download( - repo_id=model_path, + repo_id=p, filename=FLUX2_VAE_FILENAME, token=HF_TOKEN, ) @@ -207,6 +221,8 @@ class Flux2Model(BaseModel): tokenizer=tokenizer, vae=vae, transformer=None, + text_encoder_type=self.flux2_te_type, + is_guidance_distilled=self.flux2_is_guidance_distilled, ) # for quantization, it works best to do these after making the pipe pipe.transformer = transformer @@ -241,6 +257,8 @@ class Flux2Model(BaseModel): tokenizer=self.tokenizer[0], vae=unwrap_model(self.vae), transformer=unwrap_model(self.transformer), + text_encoder_type=self.flux2_te_type, + is_guidance_distilled=self.flux2_is_guidance_distilled, ) pipeline = pipeline.to(self.device_torch) @@ -281,6 +299,9 @@ class Flux2Model(BaseModel): control_img = control_img.convert("RGB") control_img_list.append(control_img) + if not self.flux2_is_guidance_distilled: + extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds + img = pipeline( prompt_embeds=conditional_embeds.text_embeds, height=gen_config.height, diff --git a/extensions_built_in/diffusion_models/flux2/src/model.py b/extensions_built_in/diffusion_models/flux2/src/model.py index a78f6a28..211b438c 100644 --- a/extensions_built_in/diffusion_models/flux2/src/model.py +++ b/extensions_built_in/diffusion_models/flux2/src/model.py @@ -17,6 +17,35 @@ class Flux2Params: axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32]) theta: int = 2000 mlp_ratio: float = 3.0 + use_guidance_embed: bool = True + + +@dataclass +class Klein9BParams: + in_channels: int = 128 + context_in_dim: int = 12288 + hidden_size: int = 4096 + num_heads: int = 32 + depth: int = 8 + depth_single_blocks: int = 24 + axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32]) + theta: int = 2000 + mlp_ratio: float = 3.0 + use_guidance_embed: bool = False + + +@dataclass +class Klein4BParams: + in_channels: int = 128 + context_in_dim: int = 7680 + hidden_size: int = 3072 + num_heads: int = 24 + depth: int = 5 + depth_single_blocks: int = 20 + axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32]) + theta: int = 2000 + mlp_ratio: float = 3.0 + use_guidance_embed: bool = False class FakeConfig: @@ -50,11 +79,14 @@ class Flux2(nn.Module): self.time_in = MLPEmbedder( in_dim=256, hidden_dim=self.hidden_size, disable_bias=True ) - self.guidance_in = MLPEmbedder( - in_dim=256, hidden_dim=self.hidden_size, disable_bias=True - ) self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False) + self.use_guidance_embed = params.use_guidance_embed + if self.use_guidance_embed: + self.guidance_in = MLPEmbedder( + in_dim=256, hidden_dim=self.hidden_size, disable_bias=True + ) + self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -116,14 +148,15 @@ class Flux2(nn.Module): timesteps: Tensor, ctx: Tensor, ctx_ids: Tensor, - guidance: Tensor, + guidance: Tensor | None, ): num_txt_tokens = ctx.shape[1] timestep_emb = timestep_embedding(timesteps, 256) vec = self.time_in(timestep_emb) - guidance_emb = timestep_embedding(guidance, 256) - vec = vec + self.guidance_in(guidance_emb) + if self.use_guidance_embed: + guidance_emb = timestep_embedding(guidance, 256) + vec = vec + self.guidance_in(guidance_emb) double_block_mod_img = self.double_stream_modulation_img(vec) double_block_mod_txt = self.double_stream_modulation_txt(vec) diff --git a/extensions_built_in/diffusion_models/flux2/src/pipeline.py b/extensions_built_in/diffusion_models/flux2/src/pipeline.py index ee0e8490..8e646393 100644 --- a/extensions_built_in/diffusion_models/flux2/src/pipeline.py +++ b/extensions_built_in/diffusion_models/flux2/src/pipeline.py @@ -4,8 +4,6 @@ import numpy as np import torch import PIL.Image from dataclasses import dataclass -from typing import List, Union - from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( @@ -13,14 +11,10 @@ from diffusers.utils import ( ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline - from diffusers.utils import BaseOutput - from .autoencoder import AutoEncoder from .model import Flux2 - from einops import rearrange - from transformers import AutoProcessor, Mistral3ForConditionalGeneration from .sampling import ( @@ -41,7 +35,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.""" -OUTPUT_LAYERS = [10, 20, 30] +OUTPUT_LAYERS_MISTRAL = [10, 20, 30] +OUTPUT_LAYERS_QWEN3 = [9, 18, 27] MAX_LENGTH = 512 @@ -56,6 +51,8 @@ class Flux2Pipeline(DiffusionPipeline): text_encoder: Mistral3ForConditionalGeneration, tokenizer: AutoProcessor, transformer: Flux2, + text_encoder_type: str = "mistral", # "mistral" or "qwen" + is_guidance_distilled: bool = False, ): super().__init__() @@ -70,6 +67,8 @@ class Flux2Pipeline(DiffusionPipeline): self.num_channels_latents = 128 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = 64 + self.text_encoder_type = text_encoder_type + self.is_guidance_distilled = is_guidance_distilled def format_input( self, @@ -138,12 +137,66 @@ class Flux2Pipeline(DiffusionPipeline): use_cache=False, ) - out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1) + out = torch.stack( + [output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1 + ) prompt_embeds = rearrange(out, "b c l d -> b l (c d)") # they don't return attention mask, so we create it here return prompt_embeds, None + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + all_input_ids = [] + all_attention_masks = [] + + for p in prompt: + messages = [{"role": "user", "content": p}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + model_inputs = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(model_inputs["input_ids"]) + all_attention_masks.append(model_inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1) + prompt_embeds = rearrange(out, "b c l d -> b l (c d)") + + # they dont use attention mask + return prompt_embeds, None + def encode_prompt( self, prompt: Union[str, List[str]], @@ -159,9 +212,18 @@ class Flux2Pipeline(DiffusionPipeline): batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds( - prompt, device, max_sequence_length=max_sequence_length - ) + if self.text_encoder_type == "mistral": + prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) + elif self.text_encoder_type == "qwen": + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) + else: + raise ValueError( + f"Unsupported text_encoder_type: {self.text_encoder_type}" + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -220,6 +282,7 @@ class Flux2Pipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -229,6 +292,8 @@ class Flux2Pipeline(DiffusionPipeline): latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, max_sequence_length: int = 512, @@ -236,6 +301,11 @@ class Flux2Pipeline(DiffusionPipeline): ): height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + do_guidance = ( + guidance_scale is not None + and guidance_scale > 1.0 + and not self.is_guidance_distilled + ) self._guidance_scale = guidance_scale self._current_timestep = None @@ -263,6 +333,19 @@ class Flux2Pipeline(DiffusionPipeline): ) txt, txt_ids = batched_prc_txt(prompt_embeds) + neg_txt, neg_txt_ids = None, None + + if do_guidance: + negative_prompt_embeds, _ = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + neg_txt, neg_txt_ids = batched_prc_txt(negative_prompt_embeds) # 4. Prepare latent variables\ latents = self.prepare_latents( @@ -329,6 +412,17 @@ class Flux2Pipeline(DiffusionPipeline): guidance=guidance_vec, ) + if do_guidance: + pred_uncond = self.transformer( + x=img_input, + x_ids=img_input_ids, + timesteps=t_vec, + ctx=neg_txt, + ctx_ids=neg_txt_ids, + guidance=guidance_vec, + ) + pred = pred_uncond + guidance_scale * (pred - pred_uncond) + if img_cond_seq is not None: pred = pred[:, : packed_latents.shape[1]] diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index d10875fe..64fa7cfa 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -630,6 +630,68 @@ export const modelArchs: ModelArch[] = [ disableSections: ['network.conv'], additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'], }, + { + name: 'flux2_klein_4b', + label: 'FLUX.2-klein-base-4B', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-4B', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, + { + name: 'flux2_klein_9b', + label: 'FLUX.2-klein-base-9B', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-9B', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, ].sort((a, b) => { // Sort by label, case-insensitive return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); diff --git a/version.py b/version.py index 82b3d72e..ac8c2387 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.19" +VERSION = "0.7.20"