From af8e9ea149116f878191845c22e93bfefb15791c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 18 Nov 2025 11:17:38 -0700 Subject: [PATCH 1/3] Add initial support for FLUX.2 --- .../diffusion_models/__init__.py | 2 + .../diffusion_models/flux2/__init__.py | 1 + .../diffusion_models/flux2/flux2_model.py | 490 +++++++++++++++++ .../diffusion_models/flux2/src/__init__.py | 0 .../diffusion_models/flux2/src/autoencoder.py | 370 +++++++++++++ .../diffusion_models/flux2/src/model.py | 520 ++++++++++++++++++ .../diffusion_models/flux2/src/pipeline.py | 358 ++++++++++++ .../diffusion_models/flux2/src/sampling.py | 365 ++++++++++++ ui/src/app/jobs/new/options.ts | 39 +- 9 files changed, 2144 insertions(+), 1 deletion(-) create mode 100644 extensions_built_in/diffusion_models/flux2/__init__.py create mode 100644 extensions_built_in/diffusion_models/flux2/flux2_model.py create mode 100644 extensions_built_in/diffusion_models/flux2/src/__init__.py create mode 100644 extensions_built_in/diffusion_models/flux2/src/autoencoder.py create mode 100644 extensions_built_in/diffusion_models/flux2/src/model.py create mode 100644 extensions_built_in/diffusion_models/flux2/src/pipeline.py create mode 100644 extensions_built_in/diffusion_models/flux2/src/sampling.py diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index f7d874e5..d871833a 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -5,6 +5,7 @@ from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel +from .flux2 import Flux2Model AI_TOOLKIT_MODELS = [ # put a list of models here @@ -21,4 +22,5 @@ AI_TOOLKIT_MODELS = [ QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel, + Flux2Model, ] diff --git 
a/extensions_built_in/diffusion_models/flux2/__init__.py b/extensions_built_in/diffusion_models/flux2/__init__.py new file mode 100644 index 00000000..0f6a62ed --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/__init__.py @@ -0,0 +1 @@ +from .flux2_model import Flux2Model \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py new file mode 100644 index 00000000..11154e4d --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -0,0 +1,490 @@ +import math +import os +from typing import TYPE_CHECKING, List, Optional + +import huggingface_hub +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.memory_management.manager import MemoryManager +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model + +from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from .src.model import Flux2, Flux2Params +from .src.pipeline import Flux2Pipeline +from .src.autoencoder import AutoEncoder, AutoEncoderParams +from safetensors.torch import load_file, save_file +from PIL import Image +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +from .src.sampling import ( + batched_prc_img, + batched_prc_txt, + encode_image_refs, + scatter_ids, +) + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + 
"num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True, +} + +MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +FLUX2_VAE_FILENAME = "flux2-vae.sft" +FLUX2_TRANSFORMER_FILENAME = "flux-dev-dummy.sft" + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +class Flux2Model(BaseModel): + arch = "flux2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["Flux2"] + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = True + # do not resize control images + self.use_raw_control_images = True + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Flux2 model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + transformer_path = model_path + + self.print_and_status_update("Loading transformer") + with torch.device("meta"): + transformer = Flux2(Flux2Params()) + + # use local path if provided + if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)): + transformer_path = os.path.join( + transformer_path, FLUX2_TRANSFORMER_FILENAME + ) + + if not os.path.exists(transformer_path): + # assume it is from the hub + transformer_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=FLUX2_TRANSFORMER_FILENAME, + token=HF_TOKEN, + ) + + transformer_state_dict = load_file(transformer_path, device="cpu") + + # cast to dtype + for key in transformer_state_dict: + 
transformer_state_dict[key] = transformer_state_dict[key].to(dtype) + + transformer.load_state_dict(transformer_state_dict, assign=True) + + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + self.print_and_status_update("Loading Mistral") + + text_encoder: Mistral3ForConditionalGeneration = ( + Mistral3ForConditionalGeneration.from_pretrained( + MISTRAL_PATH, + torch_dtype=dtype, + ) + ) + text_encoder.to(self.device_torch, dtype=dtype) + + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Mistral") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH) + + self.print_and_status_update("Loading VAE") + vae_path = self.model_config.vae_path + + if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)): + vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME) + + if vae_path is None or not os.path.exists(vae_path): + # assume it is from the hub + vae_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + 
filename=FLUX2_VAE_FILENAME, + token=HF_TOKEN, + ) + with torch.device("meta"): + vae = AutoEncoder(AutoEncoderParams()) + + vae_state_dict = load_file(vae_path, device="cpu") + + # cast to dtype + for key in vae_state_dict: + vae_state_dict[key] = vae_state_dict[key].to(dtype) + + vae.load_state_dict(vae_state_dict, assign=True) + + self.noise_scheduler = Flux2Model.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: Flux2Pipeline = Flux2Pipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = Flux2Model.get_train_scheduler() + + pipeline: Flux2Pipeline = Flux2Pipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: Flux2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + gen_config.width = ( + gen_config.width // 
self.get_bucket_divisibility() + ) * self.get_bucket_divisibility() + gen_config.height = ( + gen_config.height // self.get_bucket_divisibility() + ) * self.get_bucket_divisibility() + + control_img_list = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + elif gen_config.ctrl_img_1 is not None: + control_img = Image.open(gen_config.ctrl_img_1) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + if gen_config.ctrl_img_2 is not None: + control_img = Image.open(gen_config.ctrl_img_2) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + if gen_config.ctrl_img_3 is not None: + control_img = Image.open(gen_config.ctrl_img_3) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + control_img_list=control_img_list, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + guidance_embedding_scale: float, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + txt, txt_ids = batched_prc_txt(text_embeddings.text_embeds) + packed_latents, img_ids = batched_prc_img(latent_model_input) + + # prepare image conditioning if any + img_cond_seq: torch.Tensor | None = None + img_cond_seq_ids: torch.Tensor | None = None + + # handle control images; batch defaults to None, so check it before reading control_tensor_list + if batch is not None and batch.control_tensor_list is not None: + batch_size, num_channels_latents, height, width = ( + latent_model_input.shape + ) + + control_image_max_res = 1024 * 1024 + if
self.model_config.model_kwargs.get("match_target_res", False): + # use the current target size to set the control image res + control_image_res = ( + height + * self.pipeline.vae_scale_factor + * width + * self.pipeline.vae_scale_factor + ) + control_image_max_res = control_image_res + + if len(batch.control_tensor_list) != batch_size: + raise ValueError( + "Control tensor list length does not match batch size" + ) + for control_tensor_list in batch.control_tensor_list: + # control tensor list is a list of tensors for this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, height, width) + control_img = control_img.to( + self.device_torch, dtype=self.torch_dtype + ) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + + # resize to the target pixel area while keeping aspect ratio; ratio must be width / height (shape is N, C, H, W) + if self.model_config.model_kwargs.get( + "match_target_res", False + ): + ratio = control_img.shape[3] / control_img.shape[2] + c_width = math.sqrt(control_image_res * ratio) + c_height = c_width / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + # scale to -1 to 1 + control_img = control_img * 2 - 1 + controls.append(control_img) + + img_cond_seq_item, img_cond_seq_ids_item = encode_image_refs( + self.vae, controls, limit_pixels=control_image_max_res + ) + if img_cond_seq is None: + img_cond_seq = img_cond_seq_item + img_cond_seq_ids = img_cond_seq_ids_item + else: + img_cond_seq = torch.cat( + (img_cond_seq, img_cond_seq_item), dim=0 + ) + img_cond_seq_ids = torch.cat( + (img_cond_seq_ids, img_cond_seq_ids_item), dim=0 + ) + + img_input = packed_latents + img_input_ids = img_ids + + if img_cond_seq is not None: + assert img_cond_seq_ids is not None, ( + "You need to provide either both or neither of the sequence
conditioning" + ) + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + + guidance_vec = torch.full( + (img_input.shape[0],), + guidance_embedding_scale, + device=img_input.device, + dtype=img_input.dtype, + ) + + cast_dtype = self.model.dtype + + packed_noise_pred = self.transformer( + x=img_input.to(self.device_torch, cast_dtype), + x_ids=img_input_ids.to(self.device_torch), + timesteps=timestep.to(self.device_torch, cast_dtype) / 1000, + ctx=txt.to(self.device_torch, cast_dtype), + ctx_ids=txt_ids.to(self.device_torch), + guidance=guidance_vec.to(self.device_torch, cast_dtype), + ) + + if img_cond_seq is not None: + packed_noise_pred = packed_noise_pred[:, : packed_latents.shape[1]] + + if isinstance(packed_noise_pred, QTensor): + packed_noise_pred = packed_noise_pred.dequantize() + + noise_pred = torch.cat(scatter_ids(packed_noise_pred, img_ids)).squeeze(2) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, device=self.device_torch + ) + pe = PromptEmbeds(prompt_embeds) + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" + # only save the transformer weights + transformer: Flux2 = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to("cpu", dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name="flux2") + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") +
batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return "flux2" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["double_blocks", "single_blocks"] + + def convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + # Move the VAE to the target device if it is currently on the CPU + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + images = torch.stack(image_list).to(device, dtype=dtype) + + latents = self.vae.encode(images) + + return latents diff --git a/extensions_built_in/diffusion_models/flux2/src/__init__.py b/extensions_built_in/diffusion_models/flux2/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions_built_in/diffusion_models/flux2/src/autoencoder.py b/extensions_built_in/diffusion_models/flux2/src/autoencoder.py new file mode 100644 index 00000000..d5baef27 --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/src/autoencoder.py @@ -0,0 +1,370 @@ +from dataclasses import dataclass, field + +import torch +from einops import rearrange +from torch import Tensor, nn +import math + + +@dataclass +class AutoEncoderParams: + resolution: int = 256 + in_channels: int = 3 + ch: int = 128 + out_ch: int = 3 + ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4]) + num_res_blocks: int = 2 +
z_channels: int = 32 + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = 
self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 
2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h = self.quant_conv(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, 
out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + z = self.post_quant_conv(z) + + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.params = params + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + 
in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d( + math.prod(self.ps) * params.z_channels, + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def normalize(self, z): + self.bn.eval() + return self.bn(z) + + def inv_normalize(self, z): + self.bn.eval() + s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps) + m = self.bn.running_mean.view(1, -1, 1, 1) + return z * s + m + + def encode(self, x: Tensor) -> Tensor: + moments = self.encoder(x) + mean = torch.chunk(moments, 2, dim=1)[0] + + z = rearrange( + mean, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + z = self.normalize(z) + return z + + def decode(self, z: Tensor) -> Tensor: + z = self.inv_normalize(z) + z = rearrange( + z, + "... (c pi pj) i j -> ... 
c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + dec = self.decoder(z) + return dec diff --git a/extensions_built_in/diffusion_models/flux2/src/model.py b/extensions_built_in/diffusion_models/flux2/src/model.py new file mode 100644 index 00000000..a78f6a28 --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/src/model.py @@ -0,0 +1,520 @@ +import torch +from einops import rearrange +from torch import Tensor, nn +import torch.utils.checkpoint as ckpt +import math +from dataclasses import dataclass, field + + +@dataclass +class Flux2Params: + in_channels: int = 128 + context_in_dim: int = 15360 + hidden_size: int = 6144 + num_heads: int = 48 + depth: int = 8 + depth_single_blocks: int = 48 + axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32]) + theta: int = 2000 + mlp_ratio: float = 3.0 + + +class FakeConfig: + # for diffusers compatibility + def __init__(self): + self.patch_size = 1 + + +class Flux2(nn.Module): + def __init__(self, params: Flux2Params): + super().__init__() + self.config = FakeConfig() + + self.in_channels = params.in_channels + self.out_channels = params.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False) + self.time_in = MLPEmbedder( + in_dim=256, hidden_dim=self.hidden_size, disable_bias=True + ) + self.guidance_in = MLPEmbedder( + in_dim=256, hidden_dim=self.hidden_size, disable_bias=True + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False) + +
self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + disable_bias=True, + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + disable_bias=True, + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, disable_bias=True + ) + + self.final_layer = LastLayer( + self.hidden_size, + self.out_channels, + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def forward( + self, + x: Tensor, + x_ids: Tensor, + timesteps: Tensor, + ctx: Tensor, + ctx_ids: Tensor, + guidance: Tensor, + ): + num_txt_tokens = ctx.shape[1] + + timestep_emb = timestep_embedding(timesteps, 256) + vec = self.time_in(timestep_emb) + guidance_emb = timestep_embedding(guidance, 256) + vec = vec + self.guidance_in(guidance_emb) + + double_block_mod_img = self.double_stream_modulation_img(vec) + double_block_mod_txt = self.double_stream_modulation_txt(vec) + single_block_mod, _ = self.single_stream_modulation(vec) + + img = self.img_in(x) + txt = self.txt_in(ctx) + + pe_x = self.pe_embedder(x_ids) + pe_ctx = self.pe_embedder(ctx_ids) + + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + txt.requires_grad_(True) + img, txt = ckpt.checkpoint( + block, + img, + txt, + pe_x, + pe_ctx, + double_block_mod_img, + double_block_mod_txt, + ) + else: + img, txt = block( 
+ img, + txt, + pe_x, + pe_ctx, + double_block_mod_img, + double_block_mod_txt, + ) + + img = torch.cat((txt, img), dim=1) + pe = torch.cat((pe_ctx, pe_x), dim=2) + + for i, block in enumerate(self.single_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img = ckpt.checkpoint( + block, + img, + pe, + single_block_mod, + ) + else: + img = block( + img, + pe, + single_block_mod, + ) + + img = img[:, num_txt_tokens:, ...] + + img = self.final_layer(img, vec) + return img + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=False) + + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim, bias=False) + + +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool, disable_bias: bool = False): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias) + + def forward(self, vec: torch.Tensor): + out = self.lin(nn.functional.silu(vec)) + if out.ndim == 2: + out = out[:, None, :] + out = out.chunk(self.multiplier, dim=-1) + return out[:3], out[3:] if self.is_double else None + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + out_channels: int, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=False) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False) + ) + + def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + mod 
class SingleStreamBlock(nn.Module):
    """Transformer block that runs attention and a gated MLP off one fused projection.

    ``linear1`` produces the q/k/v features and the MLP input in a single
    matmul; ``linear2`` consumes the concatenated attention and MLP outputs.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()

        head_dim = hidden_size // num_heads

        # Plain (non-module) attributes.
        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.scale = head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # The gated activation halves its input, so the MLP branch is
        # projected to twice its hidden width.
        self.mlp_mult_factor = 2

        # Sub-modules (creation order kept: linear1, linear2, norm, pre_norm, mlp_act).
        fused_out = hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor
        self.linear1 = nn.Linear(hidden_size, fused_out, bias=False)
        self.linear2 = nn.Linear(
            hidden_size + self.mlp_hidden_dim, hidden_size, bias=False
        )
        self.norm = QKNorm(head_dim)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor],
    ) -> Tensor:
        """Apply the block and return ``x`` plus the gated update.

        ``mod`` is unpacked into three values (shift, scale, gate), even
        though the annotation names only two tensors. ``pe`` is handed to
        ``attention`` unchanged.
        """
        shift, scale, gate = mod
        modulated = (1 + scale) * self.pre_norm(x) + shift

        attn_width = 3 * self.hidden_size
        mlp_width = self.mlp_hidden_dim * self.mlp_mult_factor
        qkv, mlp_in = torch.split(
            self.linear1(modulated), [attn_width, mlp_width], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attn_out = attention(q, k, v, pe)

        # Activate the MLP stream, join it with attention, project back.
        merged = torch.cat((attn_out, self.mlp_act(mlp_in)), 2)
        update = self.linear2(merged)
        return x + gate * update
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.mlp_mult_factor = 2 + + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False), + SiLUActivation(), + nn.Linear(mlp_hidden_dim, hidden_size, bias=False), + ) + + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear( + hidden_size, + mlp_hidden_dim * self.mlp_mult_factor, + bias=False, + ), + SiLUActivation(), + nn.Linear(mlp_hidden_dim, hidden_size, bias=False), + ) + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + pe_ctx: Tensor, + mod_img: tuple[Tensor, Tensor], + mod_txt: tuple[Tensor, Tensor], + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = mod_img + txt_mod1, txt_mod2 = mod_txt + + img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1 + img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2 + txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1 + txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2 + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift + + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift + + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", 
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        super().__init__()
        use_bias = not disable_bias
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=use_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)


class EmbedND(nn.Module):
    """Build rotary tables for multi-axis position ids.

    ``axes_dim[i]`` is the rotary width used for the ``i``-th id column.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        # One rope() table per axis, concatenated on dim -3, then a
        # singleton dimension is inserted at index 1.
        per_axis = [
            rope(ids[..., axis], axis_dim, self.theta)
            for axis, axis_dim in enumerate(self.axes_dim)
        ]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2

    exponent = (
        -math.log(max_period)
        * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
        / half
    )
    freqs = torch.exp(exponent)

    phase = t[:, None].float() * freqs[None]
    parts = [torch.cos(phase), torch.sin(phase)]
    embedding = torch.cat(parts, dim=-1)

    if dim % 2:
        # Odd width: append one zero column so the result has `dim` features.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)

    if torch.is_floating_point(t):
        # Match dtype/device of the (floating) input.
        embedding = embedding.to(t)
    return embedding


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        in_dtype = x.dtype
        # Statistics are computed in float32, then cast back.
        x32 = x.float()
        inv_rms = torch.rsqrt(torch.mean(x32**2, dim=-1, keepdim=True) + 1e-6)
        normed = (x32 * inv_rms).to(dtype=in_dtype)
        return normed * self.scale


class QKNorm(torch.nn.Module):
    """Independent RMSNorm for queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalise ``q`` and ``k``; ``v`` only supplies the target dtype/device."""
        q_normed = self.query_norm(q)
        k_normed = self.key_norm(k)
        return q_normed.to(v), k_normed.to(v)


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotate q/k with ``apply_rope``, run SDPA and merge heads into the feature dim."""
    q_rot, k_rot = apply_rope(q, k, pe)
    out = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    return rearrange(out, "B H L D -> B L (H D)")


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Return float32 2x2 rotation entries for ``pos``; ``dim`` must be even."""
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)

    cos, sin = torch.cos(angles), torch.sin(angles)
    # Entries are stacked as [cos, -sin, sin, cos] and reshaped to 2x2.
    table = torch.stack([cos, -sin, sin, cos], dim=-1)
    table = rearrange(table, "b n d (i j) -> b n d i j", i=2, j=2)
    return table.float()
freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/extensions_built_in/diffusion_models/flux2/src/pipeline.py b/extensions_built_in/diffusion_models/flux2/src/pipeline.py new file mode 100644 index 00000000..ee0e8490 --- /dev/null +++ b/extensions_built_in/diffusion_models/flux2/src/pipeline.py @@ -0,0 +1,358 @@ +from typing import List, Optional, Union + +import numpy as np +import torch +import PIL.Image +from dataclasses import dataclass +from typing import List, Union + +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from diffusers.utils import BaseOutput + +from .autoencoder import AutoEncoder +from .model import Flux2 + +from einops import rearrange + +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from .sampling import ( + get_schedule, + batched_prc_img, + batched_prc_txt, + encode_image_refs, + scatter_ids, +) + + +@dataclass +class Flux2ImagePipelineOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. 
class Flux2Pipeline(DiffusionPipeline):
    """Sampling pipeline for FLUX.2 with optional reference-image conditioning.

    Prompts are encoded by the Mistral text encoder, latents are denoised by
    the ``Flux2`` transformer with a plain Euler update over the schedule
    returned by ``get_schedule``, and the result is decoded by the
    ``AutoEncoder``.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoEncoder,
        text_encoder: Mistral3ForConditionalGeneration,
        tokenizer: AutoProcessor,
        transformer: Flux2,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 16  # 8x plus 2x pixel shuffle
        self.num_channels_latents = 128
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 64

    def format_input(
        self,
        txt: list[str],
    ) -> list[list[dict]]:
        """Wrap each prompt in a system + user chat message list."""
        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
        # when truncation is enabled. The processor counts [IMG] tokens and fails
        # if the count changes after truncation.
        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

        return [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": SYSTEM_MESSAGE}],
                },
                {"role": "user", "content": [{"type": "text", "text": prompt}]},
            ]
            for prompt in cleaned_txt
        ]

    def _get_mistral_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 512,
    ):
        """Encode prompts with the text encoder.

        The hidden states of the layers listed in ``OUTPUT_LAYERS`` are
        stacked and folded into the feature dimension.

        Args:
            prompt: one prompt or a list of prompts.
            device: where the token ids are moved; defaults to the
                pipeline's execution device.
            dtype: accepted for signature compatibility; it is not applied
                to the returned embeddings.
            max_sequence_length: padding/truncation length for the tokenizer.

        Returns:
            ``(prompt_embeds, None)`` - no attention mask is produced.

        Raises:
            ValueError: re-raised from the tokenizer/processor.
        """
        device = device or self._execution_device

        if not isinstance(prompt, list):
            prompt = [prompt]

        # Format input messages
        messages_batch = self.format_input(txt=prompt)

        # Process all messages at once
        # with image processing a too short max length can throw an error in here.
        try:
            inputs = self.tokenizer.apply_chat_template(
                messages_batch,
                add_generation_prompt=False,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_sequence_length,
            )
        except ValueError as e:
            print(
                f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
            )
            # Bare raise keeps the original traceback intact.
            raise

        # Move to device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass through the model
        output = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
        prompt_embeds = rearrange(out, "b c l d -> b l (c d)")

        # No attention mask is returned for the embeddings; callers get None.
        return prompt_embeds, None

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
    ):
        """Return prompt embeddings repeated ``num_images_per_prompt`` times.

        When ``prompt_embeds`` is given it is used as-is (and
        ``prompt_embeds_mask`` is passed through); otherwise ``prompt`` is
        encoded with ``_get_mistral_prompt_embeds``.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
                prompt, device, max_sequence_length=max_sequence_length
            )

        # Repeat along the sequence axis, then fold the copies into the batch
        # axis so each prompt's copies are adjacent.
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        return prompt_embeds, prompt_embeds_mask

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Return caller-supplied latents (moved to device/dtype) or fresh noise.

        Fresh noise has shape ``(batch_size, num_channels_latents,
        height // vae_scale_factor, width // vae_scale_factor)``.

        Raises:
            ValueError: if ``generator`` is a list whose length differs from
                ``batch_size`` (only checked when noise is sampled).
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        height = int(height) // self.vae_scale_factor
        width = int(width) // self.vae_scale_factor
        shape = (batch_size, num_channels_latents, height, width)

        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: Optional[float] = None,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        max_sequence_length: int = 512,
        control_img_list: Optional[List[PIL.Image.Image]] = None,
    ):
        """Generate images.

        Args:
            prompt: prompt or list of prompts; may be omitted when
                ``prompt_embeds`` is given.
            height, width: output size in pixels; each defaults to
                ``default_sample_size * vae_scale_factor``.
            num_inference_steps: number of Euler steps.
            guidance_scale: value fed to the transformer's ``guidance``
                input. Required.
            num_images_per_prompt: images generated per prompt.
            generator: RNG(s) used to sample the initial noise.
            latents: optional pre-made initial latents.
            prompt_embeds, prompt_embeds_mask: pre-computed embeddings.
            output_type: ``"latent"`` returns the unpacked latents; any other
                value is forwarded to the image processor.
            return_dict: return a ``Flux2ImagePipelineOutput`` instead of a
                1-tuple.
            max_sequence_length: tokenizer padding/truncation length.
            control_img_list: optional reference images; their tokens are
                appended to the image sequence at every step.

        Raises:
            ValueError: if ``guidance_scale`` is ``None`` or neither a
                usable ``prompt`` nor ``prompt_embeds`` is supplied.
        """
        if guidance_scale is None:
            # torch.full cannot take None as a fill value; fail early and clearly.
            raise ValueError("`guidance_scale` must be provided (got None).")

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        # 1. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        elif prompt_embeds is not None:
            batch_size = prompt_embeds.shape[0]
        else:
            raise ValueError(
                "Provide `prompt` (str or list of str) or `prompt_embeds`."
            )

        device = self._execution_device

        # 2. Encode the prompt
        prompt_embeds, _ = self.encode_prompt(
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

        txt, txt_ids = batched_prc_txt(prompt_embeds)

        # 3. Prepare latent variables
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        packed_latents, img_ids = batched_prc_img(latents)

        # 4. Schedule and guidance
        timesteps = get_schedule(num_inference_steps, packed_latents.shape[1])

        self._num_timesteps = len(timesteps)

        guidance_vec = torch.full(
            (packed_latents.shape[0],),
            guidance_scale,
            device=packed_latents.device,
            dtype=packed_latents.dtype,
        )

        # 5. Optional reference-image conditioning
        if control_img_list is not None and len(control_img_list) > 0:
            img_cond_seq, img_cond_seq_ids = encode_image_refs(
                self.vae, control_img_list
            )
        else:
            img_cond_seq, img_cond_seq_ids = None, None

        has_cond = img_cond_seq is not None
        img_input_ids = img_ids
        if has_cond:
            assert img_cond_seq_ids is not None, (
                "You need to provide either both or neither of the sequence conditioning"
            )
            # Reference tokens come back with a batch of 1; broadcast them so
            # the concat along the sequence axis works for any batch size.
            n_samples = packed_latents.shape[0]
            if img_cond_seq.shape[0] == 1 and n_samples > 1:
                img_cond_seq = img_cond_seq.expand(n_samples, -1, -1)
            if img_cond_seq_ids.shape[0] == 1 and n_samples > 1:
                img_cond_seq_ids = img_cond_seq_ids.expand(n_samples, -1, -1)
            # The ids never change between steps, so build them once.
            img_input_ids = torch.cat((img_ids, img_cond_seq_ids), dim=1)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
                if self.interrupt:
                    continue
                t_vec = torch.full(
                    (packed_latents.shape[0],),
                    t_curr,
                    dtype=packed_latents.dtype,
                    device=packed_latents.device,
                )

                self._current_timestep = t_curr
                img_input = packed_latents
                if has_cond:
                    img_input = torch.cat((packed_latents, img_cond_seq), dim=1)

                pred = self.transformer(
                    x=img_input,
                    x_ids=img_input_ids,
                    timesteps=t_vec,
                    ctx=txt,
                    ctx_ids=txt_ids,
                    guidance=guidance_vec,
                )

                if has_cond:
                    # Drop predictions for the appended reference tokens.
                    pred = pred[:, : packed_latents.shape[1]]

                # Euler step along the schedule.
                packed_latents = packed_latents + (t_prev - t_curr) * pred
                progress_bar.update(1)

        self._current_timestep = None

        # 7. Post-processing
        latents = torch.cat(scatter_ids(packed_latents, img_ids)).squeeze(2)

        if output_type == "latent":
            image = latents
        else:
            latents = latents.to(self.vae.dtype)
            image = self.vae.decode(latents).float()

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return Flux2ImagePipelineOutput(images=image)
def encode_image_refs(
    ae,
    img_ctx: Union[list[Image.Image], list[torch.Tensor]],
    scale=10,
    limit_pixels=1024**2,
):
    """Encode reference images into one token sequence with position ids.

    Each image is preprocessed with ``default_prep``, encoded with
    ``ae.encode`` and tagged with its own time coordinate: ``scale`` for the
    first reference, ``2 * scale`` for the second, and so on.

    Returns:
        ``(tokens, ids)`` with a leading batch dimension of 1 and tokens cast
        to bfloat16, or ``(None, None)`` when ``img_ctx`` is empty.
    """
    if not img_ctx:
        return None, None

    prepared = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(prepared, list):
        prepared = [prepared]

    # Encode each reference image (adding a batch dim to 3-D inputs).
    encoded_refs = []
    for ref in prepared:
        batched = ref.unsqueeze(0) if ref.ndim == 3 else ref
        encoded_refs.append(ae.encode(batched.to(ae.device, ae.dtype))[0])

    # One 1-element time-offset tensor per reference.
    t_off = [
        (scale + scale * idx).view(-1)
        for idx in torch.arange(0, len(encoded_refs))
    ]

    token_list, id_list = listed_prc_img(encoded_refs, t_coord=t_off)

    # Join all references along the sequence axis, then add the batch axis.
    ref_tokens = torch.cat(token_list, dim=0).unsqueeze(0)  # (1, total_ref_tokens, C)
    ref_ids = torch.cat(id_list, dim=0).unsqueeze(0)  # (1, total_ref_tokens, 4)

    return ref_tokens.to(torch.bfloat16), ref_ids


def prc_txt(
    x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
) -> tuple[Tensor, Tensor]:
    """Return ``x`` together with (t, h, w, l) ids for a 2-D text sequence.

    ``h`` and ``w`` are fixed at 0, ``l`` enumerates the rows of ``x`` and
    ``t`` is 0 unless ``t_coord`` is supplied. ``l_coord`` must be ``None``.
    """
    assert l_coord is None, "l_coord not supported for txts"

    seq_len, _ = x.shape

    t_axis = torch.arange(1) if t_coord is None else t_coord
    dummy_h = torch.arange(1)  # dummy dimension
    dummy_w = torch.arange(1)  # dummy dimension
    l_axis = torch.arange(seq_len)

    x_ids = torch.cartesian_prod(t_axis, dummy_h, dummy_w, l_axis)
    return x, x_ids.to(x.device)


def batched_wrapper(fn):
    """Lift a per-sample ``prc_*`` function to batches; results are stacked."""

    def batched_prc(
        x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
    ) -> tuple[Tensor, Tensor]:
        results = [
            fn(
                x[i],
                None if t_coord is None else t_coord[i],
                None if l_coord is None else l_coord[i],
            )
            for i in range(len(x))
        ]
        tokens, ids = zip(*results)
        return torch.stack(tokens), torch.stack(ids)

    return batched_prc


def listed_wrapper(fn):
    """Lift a per-sample ``prc_*`` function to lists; results stay as lists."""

    def listed_prc(
        x: list[Tensor],
        t_coord: list[Tensor] | None = None,
        l_coord: list[Tensor] | None = None,
    ) -> tuple[list[Tensor], list[Tensor]]:
        results = [
            fn(
                x[i],
                None if t_coord is None else t_coord[i],
                None if l_coord is None else l_coord[i],
            )
            for i in range(len(x))
        ]
        tokens, ids = zip(*results)
        return list(tokens), list(ids)

    return listed_prc


def prc_img(
    x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
) -> tuple[Tensor, Tensor]:
    """Flatten a (c, h, w) tensor to (h*w, c) tokens with (t, h, w, l) ids."""
    _, height, width = x.shape

    t_axis = torch.arange(1) if t_coord is None else t_coord
    h_axis = torch.arange(height)
    w_axis = torch.arange(width)
    l_axis = torch.arange(1) if l_coord is None else l_coord

    x_ids = torch.cartesian_prod(t_axis, h_axis, w_axis, l_axis)
    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, x_ids.to(tokens.device)


listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
def cap_pixels(
    img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], k
):
    """Downscale ``img`` so that width * height does not exceed ``k``.

    Lists are processed element-wise. Inputs already within the budget are
    returned unchanged. Tensors are resized with bicubic interpolation
    (3-D tensors are temporarily given a batch dim); PIL images use LANCZOS.
    """
    if isinstance(img, list):
        return [cap_pixels(item, k) for item in img]

    is_tensor = isinstance(img, torch.Tensor)
    if is_tensor:
        h, w = img.shape[-2], img.shape[-1]
    else:
        w, h = img.size

    total = w * h
    if total <= k:
        return img

    # Scaling factor to reduce total pixels below K
    factor = math.sqrt(k / total)
    target_w = int(w * factor)
    target_h = int(h * factor)

    if not is_tensor:
        return img.resize((target_w, target_h), Image.Resampling.LANCZOS)

    needs_batch_dim = img.ndim == 3
    batched = img.unsqueeze(0) if needs_batch_dim else img
    resized = torch.nn.functional.interpolate(
        batched,
        size=(target_h, target_w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.squeeze(0) if needs_batch_dim else resized


def cap_min_pixels(
    img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
    max_ar=8,
    min_sidelength=64,
):
    """Validate size: reject images that are too small or too elongated.

    Returns ``img`` unchanged when it passes; lists are checked element-wise.

    Raises:
        ValueError: if a side is shorter than ``min_sidelength`` or the
            aspect ratio (either way round) exceeds ``max_ar``.
    """
    if isinstance(img, list):
        return [
            cap_min_pixels(item, max_ar=max_ar, min_sidelength=min_sidelength)
            for item in img
        ]

    if isinstance(img, torch.Tensor):
        h, w = img.shape[-2], img.shape[-1]
    else:
        w, h = img.size

    too_small = w < min_sidelength or h < min_sidelength
    if too_small:
        raise ValueError(
            f"Skipping due to minimal sidelength underschritten h {h} w {w}"
        )

    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")

    return img


def to_rgb(
    img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]:
    """Convert PIL images to RGB; tensors are passed through untouched."""
    if isinstance(img, list):
        return [to_rgb(item) for item in img]
    if not isinstance(img, torch.Tensor):
        return img.convert("RGB")
    # assume already in tensor format
    return img
def default_prep(
    img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
    limit_pixels: int,
    ensure_multiple: int = 16,
) -> torch.Tensor | list[torch.Tensor]:
    """Run the standard reference-image preprocessing chain.

    Steps, in order: ``to_rgb`` -> ``cap_min_pixels`` -> ``cap_pixels`` ->
    ``center_crop_to_multiple_of_x`` -> ``default_images_prep``.
    """
    # if passing a tensor, assume it is -1 to 1 already
    prepared = to_rgb(img)
    prepared = cap_min_pixels(prepared)  # type: ignore
    prepared = cap_pixels(prepared, limit_pixels)  # type: ignore
    prepared = center_crop_to_multiple_of_x(prepared, ensure_multiple)  # type: ignore
    return default_images_prep(prepared)


def time_shift(mu: float, sigma: float, t: Tensor):
    """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``."""
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the line through ``(x1, y1)`` and ``(x2, y2)`` as a callable."""
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    With ``shift=True`` the linear schedule is warped by ``time_shift`` using
    a ``mu`` interpolated linearly from ``image_seq_len``.
    """
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def denoise(
    model: Flux2,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float,
    # extra img tokens (sequence-wise)
    img_cond_seq: Tensor | None = None,
    img_cond_seq_ids: Tensor | None = None,
):
    """Integrate ``img`` along ``timesteps`` with Euler steps.

    At every step the model is called with the current tokens (plus the
    optional conditioning tokens appended along the sequence axis); only the
    prediction for the original ``img`` tokens is used for the update
    ``img += (t_prev - t_curr) * pred``.

    Conditioning tensors with a batch size of 1 are broadcast to the batch
    size of ``img``.

    Raises:
        AssertionError: if ``img_cond_seq`` is given without
            ``img_cond_seq_ids``.
    """
    batch = img.shape[0]
    has_cond = img_cond_seq is not None

    img_input_ids = img_ids
    if has_cond:
        assert img_cond_seq_ids is not None, (
            "You need to provide either both or neither of the sequence conditioning"
        )
        # Broadcast single-sample conditioning so the sequence-axis concat
        # is valid for batches larger than one.
        if img_cond_seq.shape[0] == 1 and batch > 1:
            img_cond_seq = img_cond_seq.expand(batch, -1, -1)
        if img_cond_seq_ids.shape[0] == 1 and batch > 1:
            img_cond_seq_ids = img_cond_seq_ids.expand(batch, -1, -1)
        # Position ids do not change between steps; concatenate once.
        img_input_ids = torch.cat((img_ids, img_cond_seq_ids), dim=1)

    guidance_vec = torch.full(
        (batch,), guidance, device=img.device, dtype=img.dtype
    )
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        img_input = torch.cat((img, img_cond_seq), dim=1) if has_cond else img
        pred = model(
            x=img_input,
            x_ids=img_input_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        if has_cond:
            # Keep only the predictions for the tokens being denoised.
            pred = pred[:, : img.shape[1]]

        img = img + (t_prev - t_curr) * pred

    return img
false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, ].sort((a, b) => { // Sort by label, case-insensitive return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); From dadbeda1978deccee2f572ee34c90d49dac69675 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 23 Nov 2025 10:51:50 -0700 Subject: [PATCH 2/3] Update test weights --- extensions_built_in/diffusion_models/flux2/flux2_model.py | 2 +- ui/src/app/jobs/new/options.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index 11154e4d..ca48ebd9 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -48,7 +48,7 @@ scheduler_config = { MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" FLUX2_VAE_FILENAME = "flux2-vae.sft" -FLUX2_TRANSFORMER_FILENAME = "flux-dev-dummy.sft" +FLUX2_TRANSFORMER_FILENAME = "flux2-final-dev.sft" HF_TOKEN = os.getenv("HF_TOKEN", None) diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index d328a2bc..4c730a5e 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -471,7 +471,7 @@ export const modelArchs: ModelArch[] = [ group: 'image', defaults: { // default updates when [selected, unselected] in 
the UI - 'config.process[0].model.name_or_path': ['ai-toolkit/f2-dummy', defaultNameOrPath], + 'config.process[0].model.name_or_path': ['ostris/f2', defaultNameOrPath], 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.low_vram': [true, false], From 01cf480233438578ecc8095b0f1d63c0a89594de Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 25 Nov 2025 08:52:19 -0700 Subject: [PATCH 3/3] Add FLUX.2 official weights --- extensions_built_in/diffusion_models/flux2/flux2_model.py | 4 ++-- ui/src/app/jobs/new/options.ts | 4 ++-- version.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py index ca48ebd9..6467ead0 100644 --- a/extensions_built_in/diffusion_models/flux2/flux2_model.py +++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -47,8 +47,8 @@ scheduler_config = { } MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" -FLUX2_VAE_FILENAME = "flux2-vae.sft" -FLUX2_TRANSFORMER_FILENAME = "flux2-final-dev.sft" +FLUX2_VAE_FILENAME = "ae.safetensors" +FLUX2_TRANSFORMER_FILENAME = "flux2-dev.safetensors" HF_TOKEN = os.getenv("HF_TOKEN", None) diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 4c730a5e..5c7fe1f0 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -467,11 +467,11 @@ export const modelArchs: ModelArch[] = [ }, { name: 'flux2', - label: 'FLUX.2(DUMMY)', + label: 'FLUX.2', group: 'image', defaults: { // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['ostris/f2', defaultNameOrPath], + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-dev', defaultNameOrPath], 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], 
'config.process[0].model.low_vram': [true, false], diff --git a/version.py b/version.py index 9b8720f5..c6746c8c 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.4" \ No newline at end of file +VERSION = "0.7.5" \ No newline at end of file