Merge pull request #525 from ostris/flux2

Add support for FLUX.2
This commit is contained in:
Jaret Burkett
2025-11-25 07:53:36 -08:00
committed by GitHub
10 changed files with 2145 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
from .flux2 import Flux2Model
AI_TOOLKIT_MODELS = [
# put a list of models here
@@ -21,4 +22,5 @@ AI_TOOLKIT_MODELS = [
QwenImageModel,
QwenImageEditModel,
QwenImageEditPlusModel,
Flux2Model,
]

View File

@@ -0,0 +1 @@
from .flux2_model import Flux2Model

View File

@@ -0,0 +1,490 @@
import math
import os
from typing import TYPE_CHECKING, List, Optional
import huggingface_hub
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.memory_management.manager import MemoryManager
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from .src.model import Flux2, Flux2Params
from .src.pipeline import Flux2Pipeline
from .src.autoencoder import AutoEncoder, AutoEncoderParams
from safetensors.torch import load_file, save_file
from PIL import Image
import torch.nn.functional as F
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from .src.sampling import (
batched_prc_img,
batched_prc_txt,
encode_image_refs,
scatter_ids,
)
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True,
}
MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
FLUX2_VAE_FILENAME = "ae.safetensors"
FLUX2_TRANSFORMER_FILENAME = "flux2-dev.safetensors"
HF_TOKEN = os.getenv("HF_TOKEN", None)
class Flux2Model(BaseModel):
arch = "flux2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["Flux2"]
# control images will come in as a list for encoding some things if true
self.has_multiple_control_images = True
# do not resize control images
self.use_raw_control_images = True
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 16
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading Flux2 model")
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
transformer_path = model_path
self.print_and_status_update("Loading transformer")
with torch.device("meta"):
transformer = Flux2(Flux2Params())
# use local path if provided
if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
transformer_path = os.path.join(
transformer_path, FLUX2_TRANSFORMER_FILENAME
)
if not os.path.exists(transformer_path):
# assume it is from the hub
transformer_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=FLUX2_TRANSFORMER_FILENAME,
token=HF_TOKEN,
)
transformer_state_dict = load_file(transformer_path, device="cpu")
# cast to dtype
for key in transformer_state_dict:
transformer_state_dict[key] = transformer_state_dict[key].to(dtype)
transformer.load_state_dict(transformer_state_dict, assign=True)
transformer.to(self.quantize_device, dtype=dtype)
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
self.print_and_status_update("Quantizing Transformer")
quantize_model(self, transformer)
flush()
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_transformer_percent > 0
):
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to("cpu")
self.print_and_status_update("Loading Mistral")
text_encoder: Mistral3ForConditionalGeneration = (
Mistral3ForConditionalGeneration.from_pretrained(
MISTRAL_PATH,
torch_dtype=dtype,
)
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Mistral")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
self.print_and_status_update("Loading VAE")
vae_path = self.model_config.vae_path
if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
if vae_path is None or not os.path.exists(vae_path):
# assume it is from the hub
vae_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=FLUX2_VAE_FILENAME,
token=HF_TOKEN,
)
with torch.device("meta"):
vae = AutoEncoder(AutoEncoderParams())
vae_state_dict = load_file(vae_path, device="cpu")
# cast to dtype
for key in vae_state_dict:
vae_state_dict[key] = vae_state_dict[key].to(dtype)
vae.load_state_dict(vae_state_dict, assign=True)
self.noise_scheduler = Flux2Model.get_train_scheduler()
self.print_and_status_update("Making pipe")
pipe: Flux2Pipeline = Flux2Pipeline(
scheduler=self.noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
transformer=None,
)
# for quantization, it works best to do these after making the pipe
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer]
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = Flux2Model.get_train_scheduler()
pipeline: Flux2Pipeline = Flux2Pipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: Flux2Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
gen_config.width = (
gen_config.width // self.get_bucket_divisibility()
) * self.get_bucket_divisibility()
gen_config.height = (
gen_config.height // self.get_bucket_divisibility()
) * self.get_bucket_divisibility()
control_img_list = []
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
elif gen_config.ctrl_img_1 is not None:
control_img = Image.open(gen_config.ctrl_img_1)
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
if gen_config.ctrl_img_2 is not None:
control_img = Image.open(gen_config.ctrl_img_2)
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
if gen_config.ctrl_img_3 is not None:
control_img = Image.open(gen_config.ctrl_img_3)
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
control_img_list=control_img_list,
**extra,
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
guidance_embedding_scale: float,
batch: "DataLoaderBatchDTO" = None,
**kwargs,
):
with torch.no_grad():
txt, txt_ids = batched_prc_txt(text_embeddings.text_embeds)
packed_latents, img_ids = batched_prc_img(latent_model_input)
# prepare image conditioning if any
img_cond_seq: torch.Tensor | None = None
img_cond_seq_ids: torch.Tensor | None = None
# handle control images
if batch.control_tensor_list is not None:
batch_size, num_channels_latents, height, width = (
latent_model_input.shape
)
control_image_max_res = 1024 * 1024
if self.model_config.model_kwargs.get("match_target_res", False):
# use the current target size to set the control image res
control_image_res = (
height
* self.pipeline.vae_scale_factor
* width
* self.pipeline.vae_scale_factor
)
control_image_max_res = control_image_res
if len(batch.control_tensor_list) != batch_size:
raise ValueError(
"Control tensor list length does not match batch size"
)
for control_tensor_list in batch.control_tensor_list:
# control tensor list is a list of tensors for this batch item
controls = []
# pack control
for control_img in control_tensor_list:
# control images are 0 - 1 scale, shape (1, ch, height, width)
control_img = control_img.to(
self.device_torch, dtype=self.torch_dtype
)
# if it is only 3 dim, add batch dim
if len(control_img.shape) == 3:
control_img = control_img.unsqueeze(0)
# resize to fit within max res while keeping aspect ratio
if self.model_config.model_kwargs.get(
"match_target_res", False
):
ratio = control_img.shape[2] / control_img.shape[3]
c_width = math.sqrt(control_image_res * ratio)
c_height = c_width / ratio
c_width = round(c_width / 32) * 32
c_height = round(c_height / 32) * 32
control_img = F.interpolate(
control_img, size=(c_height, c_width), mode="bilinear"
)
# scale to -1 to 1
control_img = control_img * 2 - 1
controls.append(control_img)
img_cond_seq_item, img_cond_seq_ids_item = encode_image_refs(
self.vae, controls, limit_pixels=control_image_max_res
)
if img_cond_seq is None:
img_cond_seq = img_cond_seq_item
img_cond_seq_ids = img_cond_seq_ids_item
else:
img_cond_seq = torch.cat(
(img_cond_seq, img_cond_seq_item), dim=0
)
img_cond_seq_ids = torch.cat(
(img_cond_seq_ids, img_cond_seq_ids_item), dim=0
)
img_input = packed_latents
img_input_ids = img_ids
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
guidance_vec = torch.full(
(img_input.shape[0],),
guidance_embedding_scale,
device=img_input.device,
dtype=img_input.dtype,
)
cast_dtype = self.model.dtype
packed_noise_pred = self.transformer(
x=img_input.to(self.device_torch, cast_dtype),
x_ids=img_input_ids.to(self.device_torch),
timesteps=timestep.to(self.device_torch, cast_dtype) / 1000,
ctx=txt.to(self.device_torch, cast_dtype),
ctx_ids=txt_ids.to(self.device_torch),
guidance=guidance_vec.to(self.device_torch, cast_dtype),
)
if img_cond_seq is not None:
packed_noise_pred = packed_noise_pred[:, : packed_latents.shape[1]]
if isinstance(packed_noise_pred, QTensor):
packed_noise_pred = packed_noise_pred.dequantize()
noise_pred = torch.cat(scatter_ids(packed_noise_pred, img_ids)).squeeze(2)
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
prompt, device=self.device_torch
)
pe = PromptEmbeds(prompt_embeds)
return pe
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
def save_model(self, output_path, meta, save_dtype):
if not output_path.endswith(".safetensors"):
output_path = output_path + ".safetensors"
# only save the unet
transformer: Flux2 = unwrap_model(self.model)
state_dict = transformer.state_dict()
save_dict = {}
for k, v in state_dict.items():
if isinstance(v, QTensor):
v = v.dequantize()
save_dict[k] = v.clone().to("cpu", dtype=save_dtype)
meta = get_meta_for_safetensors(meta, name="flux2")
save_file(save_dict, output_path, metadata=meta)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get("noise")
batch = kwargs.get("batch")
return (noise - batch.latents).detach()
def get_base_model_version(self):
return "flux2"
def get_transformer_block_names(self) -> Optional[List[str]]:
return ["double_blocks", "single_blocks"]
def convert_lora_weights_before_save(self, state_dict):
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
# Move to vae to device if on cpu
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
# move to device and dtype
image_list = [image.to(device, dtype=dtype) for image in image_list]
images = torch.stack(image_list).to(device, dtype=dtype)
latents = self.vae.encode(images)
return latents

View File

@@ -0,0 +1,370 @@
from dataclasses import dataclass, field
import torch
from einops import rearrange
from torch import Tensor, nn
import math
@dataclass
class AutoEncoderParams:
resolution: int = 256
in_channels: int = 3
ch: int = 128
out_ch: int = 3
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
num_res_blocks: int = 2
z_channels: int = 32
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h = self.quant_conv(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
z = self.post_quant_conv(z)
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(
math.prod(self.ps) * params.z_channels,
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def normalize(self, z):
self.bn.eval()
return self.bn(z)
def inv_normalize(self, z):
self.bn.eval()
s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
m = self.bn.running_mean.view(1, -1, 1, 1)
return z * s + m
def encode(self, x: Tensor) -> Tensor:
moments = self.encoder(x)
mean = torch.chunk(moments, 2, dim=1)[0]
z = rearrange(
mean,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = self.normalize(z)
return z
def decode(self, z: Tensor) -> Tensor:
z = self.inv_normalize(z)
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
dec = self.decoder(z)
return dec

View File

@@ -0,0 +1,520 @@
import torch
from einops import rearrange
from torch import Tensor, nn
import torch.utils.checkpoint as ckpt
import math
from dataclasses import dataclass, field
@dataclass
class Flux2Params:
in_channels: int = 128
context_in_dim: int = 15360
hidden_size: int = 6144
num_heads: int = 48
depth: int = 8
depth_single_blocks: int = 48
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
class FakeConfig:
# for diffusers compatability
def __init__(self):
self.patch_size = 1
class Flux2(nn.Module):
def __init__(self, params: Flux2Params):
super().__init__()
self.config = FakeConfig()
self.in_channels = params.in_channels
self.out_channels = params.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
self.time_in = MLPEmbedder(
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
)
self.guidance_in = MLPEmbedder(
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth_single_blocks)
]
)
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.single_stream_modulation = Modulation(
self.hidden_size, double=False, disable_bias=True
)
self.final_layer = LastLayer(
self.hidden_size,
self.out_channels,
)
self.gradient_checkpointing = False
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
def forward(
self,
x: Tensor,
x_ids: Tensor,
timesteps: Tensor,
ctx: Tensor,
ctx_ids: Tensor,
guidance: Tensor,
):
num_txt_tokens = ctx.shape[1]
timestep_emb = timestep_embedding(timesteps, 256)
vec = self.time_in(timestep_emb)
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
double_block_mod_img = self.double_stream_modulation_img(vec)
double_block_mod_txt = self.double_stream_modulation_txt(vec)
single_block_mod, _ = self.single_stream_modulation(vec)
img = self.img_in(x)
txt = self.txt_in(ctx)
pe_x = self.pe_embedder(x_ids)
pe_ctx = self.pe_embedder(ctx_ids)
for block in self.double_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
txt.requires_grad_(True)
img, txt = ckpt.checkpoint(
block,
img,
txt,
pe_x,
pe_ctx,
double_block_mod_img,
double_block_mod_txt,
)
else:
img, txt = block(
img,
txt,
pe_x,
pe_ctx,
double_block_mod_img,
double_block_mod_txt,
)
img = torch.cat((txt, img), dim=1)
pe = torch.cat((pe_ctx, pe_x), dim=2)
for i, block in enumerate(self.single_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
img = ckpt.checkpoint(
block,
img,
pe,
single_block_mod,
)
else:
img = block(
img,
pe,
single_block_mod,
)
img = img[:, num_txt_tokens:, ...]
img = self.final_layer(img, vec)
return img
class SelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim, bias=False)
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
def forward(self, vec: torch.Tensor):
out = self.lin(nn.functional.silu(vec))
if out.ndim == 2:
out = out[:, None, :]
out = out.chunk(self.multiplier, dim=-1)
return out[:3], out[3:] if self.is_double else None
class LastLayer(nn.Module):
def __init__(
self,
hidden_size: int,
out_channels: int,
):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=False)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False)
)
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
mod = self.adaLN_modulation(vec)
shift, scale = mod.chunk(2, dim=-1)
if shift.ndim == 2:
shift = shift[:, None, :]
scale = scale[:, None, :]
x = (1 + scale) * self.norm_final(x) + shift
x = self.linear(x)
return x
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_mult_factor = 2
self.linear1 = nn.Linear(
hidden_size,
hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
)
self.linear2 = nn.Linear(
hidden_size + self.mlp_hidden_dim, hidden_size, bias=False
)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = SiLUActivation()
def forward(
self,
x: Tensor,
pe: Tensor,
mod: tuple[Tensor, Tensor],
) -> Tensor:
mod_shift, mod_scale, mod_gate = mod
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
qkv, mlp = torch.split(
self.linear1(x_mod),
[3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
dim=-1,
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod_gate * output
class DoubleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
assert hidden_size % num_heads == 0, (
f"{hidden_size=} must be divisible by {num_heads=}"
)
self.hidden_size = hidden_size
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_mult_factor = 2
self.img_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(
hidden_size,
mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
def forward(
self,
img: Tensor,
txt: Tensor,
pe: Tensor,
pe_ctx: Tensor,
mod_img: tuple[Tensor, Tensor],
mod_txt: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = mod_img
txt_mod1, txt_mod2 = mod_txt
img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
pe = torch.cat((pe_ctx, pe), dim=2)
attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]
# calculate the img blocks
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
img = img + img_mod2_gate * self.img_mlp(
(1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
)
# calculate the txt blocks
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2_gate * self.txt_mlp(
(1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
)
return img, txt
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
emb = torch.cat(
[
rope(ids[..., i], self.axes_dim[i], self.theta)
for i in range(len(self.axes_dim))
],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
/ half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack(
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -0,0 +1,358 @@
from typing import List, Optional, Union
import numpy as np
import torch
import PIL.Image
from dataclasses import dataclass
from typing import List, Union
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
from .autoencoder import AutoEncoder
from .model import Flux2
from einops import rearrange
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from .sampling import (
get_schedule,
batched_prc_img,
batched_prc_txt,
encode_image_refs,
scatter_ids,
)
@dataclass
class Flux2ImagePipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
OUTPUT_LAYERS = [10, 20, 30]
MAX_LENGTH = 512
class Flux2Pipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoEncoder,
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
transformer: Flux2,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 16 # 8x plus 2x pixel shuffle
self.num_channels_latents = 128
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64
def format_input(
self,
txt: list[str],
) -> list[list[dict]]:
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": SYSTEM_MESSAGE}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
def _get_mistral_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 512,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
# Format input messages
messages_batch = self.format_input(txt=prompt)
# Process all messages at once
# with image processing a too short max length can throw an error in here.
try:
inputs = self.tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
except ValueError as e:
print(
f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
)
raise e
# Move to device
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Forward pass through the model
output = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
# they don't return attention mask, so we create it here
return prompt_embeds, None
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
return prompt_embeds, prompt_embeds_mask
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
max_sequence_length: int = 512,
control_img_list: Optional[List[PIL.Image.Image]] = None,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode the prompt
prompt_embeds, _ = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
txt, txt_ids = batched_prc_txt(prompt_embeds)
# 4. Prepare latent variables\
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
self.num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
packed_latents, img_ids = batched_prc_img(latents)
timesteps = get_schedule(num_inference_steps, packed_latents.shape[1])
self._num_timesteps = len(timesteps)
guidance_vec = torch.full(
(packed_latents.shape[0],),
guidance_scale,
device=packed_latents.device,
dtype=packed_latents.dtype,
)
if control_img_list is not None and len(control_img_list) > 0:
img_cond_seq, img_cond_seq_ids = encode_image_refs(
self.vae, control_img_list
)
else:
img_cond_seq, img_cond_seq_ids = None, None
# 6. Denoising loop
i = 0
with self.progress_bar(total=num_inference_steps) as progress_bar:
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
if self.interrupt:
continue
t_vec = torch.full(
(packed_latents.shape[0],),
t_curr,
dtype=packed_latents.dtype,
device=packed_latents.device,
)
self._current_timestep = t_curr
img_input = packed_latents
img_input_ids = img_ids
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = self.transformer(
x=img_input,
x_ids=img_input_ids,
timesteps=t_vec,
ctx=txt,
ctx_ids=txt_ids,
guidance=guidance_vec,
)
if img_cond_seq is not None:
pred = pred[:, : packed_latents.shape[1]]
packed_latents = packed_latents + (t_prev - t_curr) * pred
i += 1
progress_bar.update(1)
self._current_timestep = None
# 7. Post-processing
latents = torch.cat(scatter_ids(packed_latents, img_ids)).squeeze(2)
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents).float()
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return Flux2ImagePipelineOutput(images=image)

View File

@@ -0,0 +1,365 @@
import math
from typing import Callable, Union
import torch
from einops import rearrange
from PIL import Image
from torch import Tensor
from .model import Flux2
import torchvision
def compress_time(t_ids: Tensor) -> Tensor:
assert t_ids.ndim == 1
t_ids_max = torch.max(t_ids)
t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
t_remap[t_unique_sorted_ids] = torch.arange(
len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
)
t_ids_compressed = t_remap[t_ids]
return t_ids_compressed
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
"""
using position ids to scatter tokens into place
"""
x_list = []
t_coords = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
t_ids = pos[:, 0].to(torch.int64)
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
t_ids_cmpr = compress_time(t_ids)
t = torch.max(t_ids_cmpr) + 1
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
t_coords.append(torch.unique(t_ids, sorted=True))
return x_list
def encode_image_refs(
ae,
img_ctx: Union[list[Image.Image], list[torch.Tensor]],
scale=10,
limit_pixels=1024**2,
):
if not img_ctx:
return None, None
img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
if not isinstance(img_ctx_prep, list):
img_ctx_prep = [img_ctx_prep]
# Encode each reference image
encoded_refs = []
for img in img_ctx_prep:
if img.ndim == 3:
img = img.unsqueeze(0)
encoded = ae.encode(img.to(ae.device, ae.dtype))[0]
encoded_refs.append(encoded)
# Create time offsets for each reference
t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
t_off = [t.view(-1) for t in t_off]
# Process with position IDs
ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
# Concatenate all references along sequence dimension
ref_tokens = torch.cat(ref_tokens, dim=0) # (total_ref_tokens, C)
ref_ids = torch.cat(ref_ids, dim=0) # (total_ref_tokens, 4)
# Add batch dimension
ref_tokens = ref_tokens.unsqueeze(0) # (1, total_ref_tokens, C)
ref_ids = ref_ids.unsqueeze(0) # (1, total_ref_tokens, 4)
return ref_tokens.to(torch.bfloat16), ref_ids
def prc_txt(
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
) -> tuple[Tensor, Tensor]:
assert l_coord is None, "l_coord not supported for txts"
_l, _ = x.shape # noqa: F841
coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(1), # dummy dimension
"w": torch.arange(1), # dummy dimension
"l": torch.arange(_l),
}
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
return x, x_ids.to(x.device)
def batched_wrapper(fn):
def batched_prc(
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
) -> tuple[Tensor, Tensor]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
l_coord[i] if l_coord is not None else None,
)
)
x, x_ids = zip(*results)
return torch.stack(x), torch.stack(x_ids)
return batched_prc
def listed_wrapper(fn):
def listed_prc(
x: list[Tensor],
t_coord: list[Tensor] | None = None,
l_coord: list[Tensor] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
l_coord[i] if l_coord is not None else None,
)
)
x, x_ids = zip(*results)
return list(x), list(x_ids)
return listed_prc
def prc_img(
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
) -> tuple[Tensor, Tensor]:
c, h, w = x.shape # noqa: F841
x_coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(h),
"w": torch.arange(w),
"l": torch.arange(1) if l_coord is None else l_coord,
}
x_ids = torch.cartesian_prod(
x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"]
)
x = rearrange(x, "c h w -> (h w) c")
return x, x_ids.to(x.device)
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
def center_crop_to_multiple_of_x(
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], x: int
) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]:
if isinstance(img, list):
return [center_crop_to_multiple_of_x(_img, x) for _img in img] # type: ignore
if isinstance(img, torch.Tensor):
h, w = img.shape[-2], img.shape[-1]
else:
w, h = img.size
new_w = (w // x) * x
new_h = (h // x) * x
left = (w - new_w) // 2
top = (h - new_h) // 2
right = left + new_w
bottom = top + new_h
if isinstance(img, torch.Tensor):
return img[..., top:bottom, left:right]
resized = img.crop((left, top, right, bottom))
return resized
def cap_pixels(
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], k
):
if isinstance(img, list):
return [cap_pixels(_img, k) for _img in img]
if isinstance(img, torch.Tensor):
h, w = img.shape[-2], img.shape[-1]
else:
w, h = img.size
pixel_count = w * h
if pixel_count <= k:
return img
# Scaling factor to reduce total pixels below K
scale = math.sqrt(k / pixel_count)
new_w = int(w * scale)
new_h = int(h * scale)
if isinstance(img, torch.Tensor):
did_expand = False
if img.ndim == 3:
img = img.unsqueeze(0)
did_expand = True
img = torch.nn.functional.interpolate(
img,
size=(new_h, new_w),
mode="bicubic",
align_corners=False,
)
if did_expand:
img = img.squeeze(0)
return img
return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def cap_min_pixels(
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
max_ar=8,
min_sidelength=64,
):
if isinstance(img, list):
return [
cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength)
for _img in img
]
if isinstance(img, torch.Tensor):
h, w = img.shape[-2], img.shape[-1]
else:
w, h = img.size
if w < min_sidelength or h < min_sidelength:
raise ValueError(
f"Skipping due to minimal sidelength underschritten h {h} w {w}"
)
if w / h > max_ar or h / w > max_ar:
raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
return img
def to_rgb(
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]:
if isinstance(img, list):
return [
to_rgb(
_img,
)
for _img in img
]
if isinstance(img, torch.Tensor):
return img # assume already in tensor format
return img.convert("RGB")
def default_images_prep(
x: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
) -> torch.Tensor | list[torch.Tensor]:
if isinstance(x, list):
return [default_images_prep(e) for e in x] # type: ignore
if isinstance(x, torch.Tensor):
return x # assume already in tensor format
x_tensor = torchvision.transforms.ToTensor()(x)
return 2 * x_tensor - 1
def default_prep(
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
limit_pixels: int,
ensure_multiple: int = 16,
) -> torch.Tensor | list[torch.Tensor]:
# if passing a tensor, assume it is -1 to 1 already
img_rgb = to_rgb(img)
img_min = cap_min_pixels(img_rgb) # type: ignore
img_cap = cap_pixels(img_min, limit_pixels) # type: ignore
img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple) # type: ignore
img_tensor = default_images_prep(img_crop)
return img_tensor
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux2,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
):
guidance_vec = torch.full(
(img.shape[0],), guidance, device=img.device, dtype=img.dtype
)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
x=img_input,
x_ids=img_input_ids,
timesteps=t_vec,
ctx=txt,
ctx_ids=txt_ids,
guidance=guidance_vec,
)
if img_input_ids is not None:
pred = pred[:, : img.shape[1]]
img = img + (t_prev - t_curr) * pred
return img

View File

@@ -258,7 +258,13 @@ export const modelArchs: ModelArch[] = [
],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
additionalSections: [
'sample.ctrl_img',
'datasets.num_frames',
'model.low_vram',
'model.multistage',
'model.layer_offloading',
],
accuracyRecoveryAdapters: {
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
},
@@ -459,6 +465,37 @@ export const modelArchs: ModelArch[] = [
disableSections: ['network.conv'],
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
},
{
name: 'flux2',
label: 'FLUX.2',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-dev', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.unload_text_encoder': [false, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
'config.process[0].model.model_kwargs': [
{
match_target_res: false,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: [
'datasets.multi_control_paths',
'sample.multi_ctrl_imgs',
'model.low_vram',
'model.layer_offloading',
'model.qie.match_target_res',
],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });

View File

@@ -1 +1 @@
VERSION = "0.7.4"
VERSION = "0.7.5"