Add support for FLUX.2 klein base models

This commit is contained in:
Jaret Burkett
2026-01-17 17:46:25 -07:00
parent 0efed794b4
commit a6da9e37ac
8 changed files with 361 additions and 56 deletions

View File

@@ -5,7 +5,7 @@ from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
from .flux2 import Flux2Model
from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
from .z_image import ZImageModel
from .ltx2 import LTX2Model
@@ -27,4 +27,6 @@ AI_TOOLKIT_MODELS = [
Flux2Model,
ZImageModel,
LTX2Model,
Flux2Klein4BModel,
Flux2Klein9BModel,
]

View File

@@ -1 +1,2 @@
from .flux2_model import Flux2Model
from .flux2_model import Flux2Model
from .flux2_klein_model import Flux2Klein4BModel, Flux2Klein9BModel

View File

@@ -0,0 +1,92 @@
from .flux2_model import Flux2Model
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from toolkit.config_modules import ModelConfig
from toolkit.memory_management.manager import MemoryManager
from toolkit.basic import flush
from .src.model import Klein9BParams, Klein4BParams
class Flux2KleinModel(Flux2Model):
flux2_klein_te_path: str = None
flux2_te_type: str = "qwen" # "mistral" or "qwen"
flux2_vae_path: str = "ai-toolkit/flux2_vae"
flux2_is_guidance_distilled: bool = False
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs,
)
# use the new format on this new model by default
self.use_old_lokr_format = False
def load_te(self):
if self.flux2_klein_te_path is None:
raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel")
dtype = self.torch_dtype
self.print_and_status_update("Loading Qwen3")
text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained(
self.flux2_klein_te_path,
torch_dtype=dtype,
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Qwen3")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
return text_encoder, tokenizer
class Flux2Klein4BModel(Flux2KleinModel):
arch = "flux2_klein_4b"
flux2_klein_te_path: str = "Qwen/Qwen3-4B"
flux2_te_filename: str = "flux-2-klein-base-4b.safetensors"
def get_flux2_params(self):
return Klein4BParams()
def get_base_model_version(self):
return "flux2_klein_4b"
class Flux2Klein9BModel(Flux2KleinModel):
arch = "flux2_klein_9b"
flux2_klein_te_path: str = "Qwen/Qwen3-8B"
flux2_te_filename: str = "flux-2-klein-base-9b.safetensors"
def get_flux2_params(self):
return Klein9BParams()
def get_base_model_version(self):
return "flux2_klein_9b"

View File

@@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
class Flux2Model(BaseModel):
arch = "flux2"
flux2_te_type: str = "mistral" # "mistral" or "qwen"
flux2_vae_path: str = None
flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME
flux2_is_guidance_distilled: bool = True
def __init__(
self,
@@ -84,6 +88,42 @@ class Flux2Model(BaseModel):
def get_bucket_divisibility(self):
return 16
def get_flux2_params(self):
return Flux2Params()
def load_te(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading Mistral")
text_encoder: Mistral3ForConditionalGeneration = (
Mistral3ForConditionalGeneration.from_pretrained(
MISTRAL_PATH,
torch_dtype=dtype,
)
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Mistral")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
return text_encoder, tokenizer
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading Flux2 model")
@@ -93,19 +133,17 @@ class Flux2Model(BaseModel):
self.print_and_status_update("Loading transformer")
with torch.device("meta"):
transformer = Flux2(Flux2Params())
transformer = Flux2(self.get_flux2_params())
# use local path if provided
if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
transformer_path = os.path.join(
transformer_path, FLUX2_TRANSFORMER_FILENAME
)
if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
transformer_path = os.path.join(transformer_path, self.flux2_te_filename)
if not os.path.exists(transformer_path):
# assume it is from the hub
transformer_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=FLUX2_TRANSFORMER_FILENAME,
filename=self.flux2_te_filename,
token=HF_TOKEN,
)
@@ -143,35 +181,7 @@ class Flux2Model(BaseModel):
self.print_and_status_update("Moving transformer to CPU")
transformer.to("cpu")
self.print_and_status_update("Loading Mistral")
text_encoder: Mistral3ForConditionalGeneration = (
Mistral3ForConditionalGeneration.from_pretrained(
MISTRAL_PATH,
torch_dtype=dtype,
)
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Mistral")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
text_encoder, tokenizer = self.load_te()
self.print_and_status_update("Loading VAE")
vae_path = self.model_config.vae_path
@@ -179,10 +189,14 @@ class Flux2Model(BaseModel):
if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
if vae_path is None:
vae_path = self.flux2_vae_path
if vae_path is None or not os.path.exists(vae_path):
p = vae_path if vae_path is not None else model_path
# assume it is from the hub
vae_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
repo_id=p,
filename=FLUX2_VAE_FILENAME,
token=HF_TOKEN,
)
@@ -207,6 +221,8 @@ class Flux2Model(BaseModel):
tokenizer=tokenizer,
vae=vae,
transformer=None,
text_encoder_type=self.flux2_te_type,
is_guidance_distilled=self.flux2_is_guidance_distilled,
)
# for quantization, it works best to do these after making the pipe
pipe.transformer = transformer
@@ -241,6 +257,8 @@ class Flux2Model(BaseModel):
tokenizer=self.tokenizer[0],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
text_encoder_type=self.flux2_te_type,
is_guidance_distilled=self.flux2_is_guidance_distilled,
)
pipeline = pipeline.to(self.device_torch)
@@ -281,6 +299,9 @@ class Flux2Model(BaseModel):
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
if not self.flux2_is_guidance_distilled:
extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
height=gen_config.height,

View File

@@ -17,6 +17,35 @@ class Flux2Params:
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
use_guidance_embed: bool = True
@dataclass
class Klein9BParams:
in_channels: int = 128
context_in_dim: int = 12288
hidden_size: int = 4096
num_heads: int = 32
depth: int = 8
depth_single_blocks: int = 24
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
use_guidance_embed: bool = False
@dataclass
class Klein4BParams:
in_channels: int = 128
context_in_dim: int = 7680
hidden_size: int = 3072
num_heads: int = 24
depth: int = 5
depth_single_blocks: int = 20
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
use_guidance_embed: bool = False
class FakeConfig:
@@ -50,11 +79,14 @@ class Flux2(nn.Module):
self.time_in = MLPEmbedder(
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
)
self.guidance_in = MLPEmbedder(
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
self.use_guidance_embed = params.use_guidance_embed
if self.use_guidance_embed:
self.guidance_in = MLPEmbedder(
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -116,14 +148,15 @@ class Flux2(nn.Module):
timesteps: Tensor,
ctx: Tensor,
ctx_ids: Tensor,
guidance: Tensor,
guidance: Tensor | None,
):
num_txt_tokens = ctx.shape[1]
timestep_emb = timestep_embedding(timesteps, 256)
vec = self.time_in(timestep_emb)
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
if self.use_guidance_embed:
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
double_block_mod_img = self.double_stream_modulation_img(vec)
double_block_mod_txt = self.double_stream_modulation_txt(vec)

View File

@@ -4,8 +4,6 @@ import numpy as np
import torch
import PIL.Image
from dataclasses import dataclass
from typing import List, Union
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
@@ -13,14 +11,10 @@ from diffusers.utils import (
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
from .autoencoder import AutoEncoder
from .model import Flux2
from einops import rearrange
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from .sampling import (
@@ -41,7 +35,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
OUTPUT_LAYERS = [10, 20, 30]
OUTPUT_LAYERS_MISTRAL = [10, 20, 30]
OUTPUT_LAYERS_QWEN3 = [9, 18, 27]
MAX_LENGTH = 512
@@ -56,6 +51,8 @@ class Flux2Pipeline(DiffusionPipeline):
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
transformer: Flux2,
text_encoder_type: str = "mistral", # "mistral" or "qwen"
is_guidance_distilled: bool = False,
):
super().__init__()
@@ -70,6 +67,8 @@ class Flux2Pipeline(DiffusionPipeline):
self.num_channels_latents = 128
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64
self.text_encoder_type = text_encoder_type
self.is_guidance_distilled = is_guidance_distilled
def format_input(
self,
@@ -138,12 +137,66 @@ class Flux2Pipeline(DiffusionPipeline):
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
out = torch.stack(
[output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1
)
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
# they don't return attention mask, so we create it here
return prompt_embeds, None
def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 512,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
all_input_ids = []
all_attention_masks = []
for p in prompt:
messages = [{"role": "user", "content": p}]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
model_inputs = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(model_inputs["input_ids"])
all_attention_masks.append(model_inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
output = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1)
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
# they dont use attention mask
return prompt_embeds, None
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -159,9 +212,18 @@ class Flux2Pipeline(DiffusionPipeline):
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)
if self.text_encoder_type == "mistral":
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)
elif self.text_encoder_type == "qwen":
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
prompt, device, max_sequence_length=max_sequence_length
)
else:
raise ValueError(
f"Unsupported text_encoder_type: {self.text_encoder_type}"
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -220,6 +282,7 @@ class Flux2Pipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -229,6 +292,8 @@ class Flux2Pipeline(DiffusionPipeline):
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
max_sequence_length: int = 512,
@@ -236,6 +301,11 @@ class Flux2Pipeline(DiffusionPipeline):
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
do_guidance = (
guidance_scale is not None
and guidance_scale > 1.0
and not self.is_guidance_distilled
)
self._guidance_scale = guidance_scale
self._current_timestep = None
@@ -263,6 +333,19 @@ class Flux2Pipeline(DiffusionPipeline):
)
txt, txt_ids = batched_prc_txt(prompt_embeds)
neg_txt, neg_txt_ids = None, None
if do_guidance:
negative_prompt_embeds, _ = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
neg_txt, neg_txt_ids = batched_prc_txt(negative_prompt_embeds)
# 4. Prepare latent variables\
latents = self.prepare_latents(
@@ -329,6 +412,17 @@ class Flux2Pipeline(DiffusionPipeline):
guidance=guidance_vec,
)
if do_guidance:
pred_uncond = self.transformer(
x=img_input,
x_ids=img_input_ids,
timesteps=t_vec,
ctx=neg_txt,
ctx_ids=neg_txt_ids,
guidance=guidance_vec,
)
pred = pred_uncond + guidance_scale * (pred - pred_uncond)
if img_cond_seq is not None:
pred = pred[:, : packed_latents.shape[1]]

View File

@@ -630,6 +630,68 @@ export const modelArchs: ModelArch[] = [
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
},
{
name: 'flux2_klein_4b',
label: 'FLUX.2-klein-base-4B',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-4B', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.unload_text_encoder': [false, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
'config.process[0].model.model_kwargs': [
{
match_target_res: false,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: [
'datasets.multi_control_paths',
'sample.multi_ctrl_imgs',
'model.low_vram',
'model.layer_offloading',
'model.qie.match_target_res',
],
},
{
name: 'flux2_klein_9b',
label: 'FLUX.2-klein-base-9B',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-9B', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.unload_text_encoder': [false, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
'config.process[0].model.model_kwargs': [
{
match_target_res: false,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: [
'datasets.multi_control_paths',
'sample.multi_ctrl_imgs',
'model.low_vram',
'model.layer_offloading',
'model.qie.match_target_res',
],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });

View File

@@ -1 +1 @@
VERSION = "0.7.19"
VERSION = "0.7.20"