Add support for FLUX.2 klein base models

Jaret Burkett
2026-01-17 17:46:25 -07:00
parent 0efed794b4
commit a6da9e37ac
8 changed files with 361 additions and 56 deletions

View File

@@ -5,7 +5,7 @@ from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
 from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
 from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
-from .flux2 import Flux2Model
+from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
 from .z_image import ZImageModel
 from .ltx2 import LTX2Model

@@ -27,4 +27,6 @@ AI_TOOLKIT_MODELS = [
     Flux2Model,
     ZImageModel,
     LTX2Model,
+    Flux2Klein4BModel,
+    Flux2Klein9BModel,
 ]

View File

@@ -1 +1,2 @@
 from .flux2_model import Flux2Model
+from .flux2_klein_model import Flux2Klein4BModel, Flux2Klein9BModel

View File

@@ -0,0 +1,92 @@
+from .flux2_model import Flux2Model
+from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
+from optimum.quanto import freeze
+from toolkit.util.quantize import quantize, get_qtype
+from toolkit.config_modules import ModelConfig
+from toolkit.memory_management.manager import MemoryManager
+from toolkit.basic import flush
+from .src.model import Klein9BParams, Klein4BParams
+
+
+class Flux2KleinModel(Flux2Model):
+    flux2_klein_te_path: str = None
+    flux2_te_type: str = "qwen"  # "mistral" or "qwen"
+    flux2_vae_path: str = "ai-toolkit/flux2_vae"
+    flux2_is_guidance_distilled: bool = False
+
+    def __init__(
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype="bf16",
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs,
+    ):
+        super().__init__(
+            device,
+            model_config,
+            dtype,
+            custom_pipeline,
+            noise_scheduler,
+            **kwargs,
+        )
+        # use the new format on this new model by default
+        self.use_old_lokr_format = False
+
+    def load_te(self):
+        if self.flux2_klein_te_path is None:
+            raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel")
+        dtype = self.torch_dtype
+        self.print_and_status_update("Loading Qwen3")
+        text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained(
+            self.flux2_klein_te_path,
+            torch_dtype=dtype,
+        )
+        text_encoder.to(self.device_torch, dtype=dtype)
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Qwen3")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
+            freeze(text_encoder)
+            flush()
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent > 0
+        ):
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
+            )
+
+        tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
+        return text_encoder, tokenizer
+
+
+class Flux2Klein4BModel(Flux2KleinModel):
+    arch = "flux2_klein_4b"
+    flux2_klein_te_path: str = "Qwen/Qwen3-4B"
+    flux2_te_filename: str = "flux-2-klein-base-4b.safetensors"
+
+    def get_flux2_params(self):
+        return Klein4BParams()
+
+    def get_base_model_version(self):
+        return "flux2_klein_4b"
+
+
+class Flux2Klein9BModel(Flux2KleinModel):
+    arch = "flux2_klein_9b"
+    flux2_klein_te_path: str = "Qwen/Qwen3-8B"
+    flux2_te_filename: str = "flux-2-klein-base-9b.safetensors"
+
+    def get_flux2_params(self):
+        return Klein9BParams()
+
+    def get_base_model_version(self):
+        return "flux2_klein_9b"

View File

@@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
class Flux2Model(BaseModel): class Flux2Model(BaseModel):
arch = "flux2" arch = "flux2"
flux2_te_type: str = "mistral" # "mistral" or "qwen"
flux2_vae_path: str = None
flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME
flux2_is_guidance_distilled: bool = True
def __init__( def __init__(
self, self,
@@ -84,6 +88,42 @@ class Flux2Model(BaseModel):
     def get_bucket_divisibility(self):
         return 16

+    def get_flux2_params(self):
+        return Flux2Params()
+
+    def load_te(self):
+        dtype = self.torch_dtype
+        self.print_and_status_update("Loading Mistral")
+        text_encoder: Mistral3ForConditionalGeneration = (
+            Mistral3ForConditionalGeneration.from_pretrained(
+                MISTRAL_PATH,
+                torch_dtype=dtype,
+            )
+        )
+        text_encoder.to(self.device_torch, dtype=dtype)
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Mistral")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
+            freeze(text_encoder)
+            flush()
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent > 0
+        ):
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
+            )
+
+        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        return text_encoder, tokenizer
+
     def load_model(self):
         dtype = self.torch_dtype
         self.print_and_status_update("Loading Flux2 model")
@@ -93,19 +133,17 @@ class Flux2Model(BaseModel):
         self.print_and_status_update("Loading transformer")

         with torch.device("meta"):
-            transformer = Flux2(Flux2Params())
+            transformer = Flux2(self.get_flux2_params())

         # use local path if provided
-        if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
-            transformer_path = os.path.join(
-                transformer_path, FLUX2_TRANSFORMER_FILENAME
-            )
+        if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
+            transformer_path = os.path.join(transformer_path, self.flux2_te_filename)

         if not os.path.exists(transformer_path):
             # assume it is from the hub
             transformer_path = huggingface_hub.hf_hub_download(
                 repo_id=model_path,
-                filename=FLUX2_TRANSFORMER_FILENAME,
+                filename=self.flux2_te_filename,
                 token=HF_TOKEN,
             )
@@ -143,35 +181,7 @@ class Flux2Model(BaseModel):
         self.print_and_status_update("Moving transformer to CPU")
         transformer.to("cpu")

-        self.print_and_status_update("Loading Mistral")
-        text_encoder: Mistral3ForConditionalGeneration = (
-            Mistral3ForConditionalGeneration.from_pretrained(
-                MISTRAL_PATH,
-                torch_dtype=dtype,
-            )
-        )
-        text_encoder.to(self.device_torch, dtype=dtype)
-        flush()
-
-        if self.model_config.quantize_te:
-            self.print_and_status_update("Quantizing Mistral")
-            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
-            freeze(text_encoder)
-            flush()
-
-        if (
-            self.model_config.layer_offloading
-            and self.model_config.layer_offloading_text_encoder_percent > 0
-        ):
-            MemoryManager.attach(
-                text_encoder,
-                self.device_torch,
-                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
-            )
-
-        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        text_encoder, tokenizer = self.load_te()

         self.print_and_status_update("Loading VAE")
         vae_path = self.model_config.vae_path
@@ -179,10 +189,14 @@ class Flux2Model(BaseModel):
         if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
             vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)

+        if vae_path is None:
+            vae_path = self.flux2_vae_path
+
         if vae_path is None or not os.path.exists(vae_path):
+            p = vae_path if vae_path is not None else model_path
             # assume it is from the hub
             vae_path = huggingface_hub.hf_hub_download(
-                repo_id=model_path,
+                repo_id=p,
                 filename=FLUX2_VAE_FILENAME,
                 token=HF_TOKEN,
             )
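The hunk above adds a class-level VAE fallback: an explicit model_config.vae_path wins, then a local FLUX2_VAE_FILENAME next to the model, then the class default (for the klein models, "ai-toolkit/flux2_vae" is a hub repo id rather than a local path, so it falls through to hf_hub_download). A sketch of that resolution order, assuming the same precedence; the filename value below is a placeholder, not the real FLUX2_VAE_FILENAME:

import os

def resolve_vae_path(config_vae_path, model_path, class_default, filename="vae.safetensors"):
    # Mirrors the fallback order above; "filename" stands in for FLUX2_VAE_FILENAME.
    vae_path = config_vae_path
    if vae_path is None and os.path.exists(os.path.join(model_path, filename)):
        vae_path = os.path.join(model_path, filename)
    if vae_path is None:
        vae_path = class_default  # e.g. "ai-toolkit/flux2_vae" on klein models
    if vae_path is None or not os.path.exists(vae_path):
        repo_id = vae_path if vae_path is not None else model_path
        return ("hub", repo_id)  # would call huggingface_hub.hf_hub_download here
    return ("local", vae_path)

print(resolve_vae_path(None, "black-forest-labs/FLUX.2-klein-base-4B", "ai-toolkit/flux2_vae"))
# ('hub', 'ai-toolkit/flux2_vae')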
@@ -207,6 +221,8 @@ class Flux2Model(BaseModel):
             tokenizer=tokenizer,
             vae=vae,
             transformer=None,
+            text_encoder_type=self.flux2_te_type,
+            is_guidance_distilled=self.flux2_is_guidance_distilled,
         )

         # for quantization, it works best to do these after making the pipe
         pipe.transformer = transformer
@@ -241,6 +257,8 @@ class Flux2Model(BaseModel):
             tokenizer=self.tokenizer[0],
             vae=unwrap_model(self.vae),
             transformer=unwrap_model(self.transformer),
+            text_encoder_type=self.flux2_te_type,
+            is_guidance_distilled=self.flux2_is_guidance_distilled,
         )

         pipeline = pipeline.to(self.device_torch)
@@ -281,6 +299,9 @@ class Flux2Model(BaseModel):
                 control_img = control_img.convert("RGB")
                 control_img_list.append(control_img)

+        if not self.flux2_is_guidance_distilled:
+            extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds
+
         img = pipeline(
             prompt_embeds=conditional_embeds.text_embeds,
             height=gen_config.height,

View File

@@ -17,6 +17,35 @@ class Flux2Params:
     axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
     theta: int = 2000
     mlp_ratio: float = 3.0
+    use_guidance_embed: bool = True
+
+
+@dataclass
+class Klein9BParams:
+    in_channels: int = 128
+    context_in_dim: int = 12288
+    hidden_size: int = 4096
+    num_heads: int = 32
+    depth: int = 8
+    depth_single_blocks: int = 24
+    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
+    theta: int = 2000
+    mlp_ratio: float = 3.0
+    use_guidance_embed: bool = False
+
+
+@dataclass
+class Klein4BParams:
+    in_channels: int = 128
+    context_in_dim: int = 7680
+    hidden_size: int = 3072
+    num_heads: int = 24
+    depth: int = 5
+    depth_single_blocks: int = 20
+    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
+    theta: int = 2000
+    mlp_ratio: float = 3.0
+    use_guidance_embed: bool = False
+

 class FakeConfig:
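The klein dataclasses differ from Flux2Params mainly in width/depth and in context_in_dim, which has to equal the number of stacked text-encoder hidden layers times the encoder's hidden size (three layers are concatenated along the feature axis in the pipeline). A quick arithmetic check, taking the published Qwen3 hidden sizes as an assumption, since they are not in this diff:

# context_in_dim = len(OUTPUT_LAYERS_QWEN3) * text-encoder hidden size
NUM_STACKED_LAYERS = 3  # OUTPUT_LAYERS_QWEN3 = [9, 18, 27] in the pipeline
QWEN3_HIDDEN = {"Qwen/Qwen3-4B": 2560, "Qwen/Qwen3-8B": 4096}  # assumed from the model cards

assert NUM_STACKED_LAYERS * QWEN3_HIDDEN["Qwen/Qwen3-4B"] == 7680    # Klein4BParams.context_in_dim
assert NUM_STACKED_LAYERS * QWEN3_HIDDEN["Qwen/Qwen3-8B"] == 12288   # Klein9BParams.context_in_dim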
@@ -50,11 +79,14 @@ class Flux2(nn.Module):
         self.time_in = MLPEmbedder(
             in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
         )
-        self.guidance_in = MLPEmbedder(
-            in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
-        )
         self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
+        self.use_guidance_embed = params.use_guidance_embed
+        if self.use_guidance_embed:
+            self.guidance_in = MLPEmbedder(
+                in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
+            )
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -116,14 +148,15 @@ class Flux2(nn.Module):
         timesteps: Tensor,
         ctx: Tensor,
         ctx_ids: Tensor,
-        guidance: Tensor,
+        guidance: Tensor | None,
     ):
         num_txt_tokens = ctx.shape[1]
         timestep_emb = timestep_embedding(timesteps, 256)
         vec = self.time_in(timestep_emb)

-        guidance_emb = timestep_embedding(guidance, 256)
-        vec = vec + self.guidance_in(guidance_emb)
+        if self.use_guidance_embed:
+            guidance_emb = timestep_embedding(guidance, 256)
+            vec = vec + self.guidance_in(guidance_emb)

         double_block_mod_img = self.double_stream_modulation_img(vec)
         double_block_mod_txt = self.double_stream_modulation_txt(vec)
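With use_guidance_embed=False the conditioning vector reduces to the timestep embedding alone, which is why the forward signature can now accept guidance=None for the klein models. A toy illustration of that gating on plain tensors (the real code runs both embeddings through MLPEmbedder first, omitted here):

import torch

def make_vec(t_emb: torch.Tensor, g_emb, use_guidance_embed: bool) -> torch.Tensor:
    # Same control flow as the hunk above.
    vec = t_emb
    if use_guidance_embed:
        assert g_emb is not None, "distilled models need a guidance embedding"
        vec = vec + g_emb
    return vec

t_emb = torch.randn(1, 3072)
print(make_vec(t_emb, None, use_guidance_embed=False).shape)            # klein path: guidance ignored
print(make_vec(t_emb, torch.randn(1, 3072), use_guidance_embed=True).shape)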

View File

@@ -4,8 +4,6 @@ import numpy as np
 import torch
 import PIL.Image
 from dataclasses import dataclass
-from typing import List, Union
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
@@ -13,14 +11,10 @@ from diffusers.utils import (
 )
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import BaseOutput
 from .autoencoder import AutoEncoder
 from .model import Flux2
 from einops import rearrange
-from transformers import AutoProcessor, Mistral3ForConditionalGeneration
 from .sampling import (
@@ -41,7 +35,8 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
 attribution and actions without speculation."""

-OUTPUT_LAYERS = [10, 20, 30]
+OUTPUT_LAYERS_MISTRAL = [10, 20, 30]
+OUTPUT_LAYERS_QWEN3 = [9, 18, 27]
 MAX_LENGTH = 512
@@ -56,6 +51,8 @@ class Flux2Pipeline(DiffusionPipeline):
         text_encoder: Mistral3ForConditionalGeneration,
         tokenizer: AutoProcessor,
         transformer: Flux2,
+        text_encoder_type: str = "mistral",  # "mistral" or "qwen"
+        is_guidance_distilled: bool = False,
     ):
         super().__init__()
@@ -70,6 +67,8 @@ class Flux2Pipeline(DiffusionPipeline):
         self.num_channels_latents = 128
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.default_sample_size = 64
+        self.text_encoder_type = text_encoder_type
+        self.is_guidance_distilled = is_guidance_distilled

     def format_input(
         self,
@@ -138,12 +137,66 @@ class Flux2Pipeline(DiffusionPipeline):
             use_cache=False,
         )

-        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
+        out = torch.stack(
+            [output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1
+        )
         prompt_embeds = rearrange(out, "b c l d -> b l (c d)")

         # they don't return attention mask, so we create it here
         return prompt_embeds, None

+    def _get_qwen_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 512,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+
+        all_input_ids = []
+        all_attention_masks = []
+        for p in prompt:
+            messages = [{"role": "user", "content": p}]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            model_inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=max_sequence_length,
+            )
+            all_input_ids.append(model_inputs["input_ids"])
+            all_attention_masks.append(model_inputs["attention_mask"])
+
+        input_ids = torch.cat(all_input_ids, dim=0).to(device)
+        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+        output = self.text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            use_cache=False,
+        )
+
+        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1)
+        prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
+        # they don't use an attention mask
+        return prompt_embeds, None
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
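The stack-then-rearrange step is the only shape trick here: three (b, l, d) hidden states become (b, 3, l, d), and collapsing the layer and feature axes yields (b, l, 3d), which is what the transformer's txt_in expects as context_in_dim. A shape walk-through with illustrative sizes (2560 is the assumed Qwen3-4B hidden size):

import torch
from einops import rearrange

b, l, d = 2, 512, 2560  # batch, sequence length, assumed Qwen3-4B hidden size
hidden_states = [torch.randn(b, l, d) for _ in range(3)]  # layers 9, 18, 27

out = torch.stack(hidden_states, dim=1)                 # (b, 3, l, d)
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")  # (b, l, 3 * d)
print(prompt_embeds.shape)                              # torch.Size([2, 512, 7680])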
@@ -159,9 +212,18 @@ class Flux2Pipeline(DiffusionPipeline):
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
-                prompt, device, max_sequence_length=max_sequence_length
-            )
+            if self.text_encoder_type == "mistral":
+                prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
+                    prompt, device, max_sequence_length=max_sequence_length
+                )
+            elif self.text_encoder_type == "qwen":
+                prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                    prompt, device, max_sequence_length=max_sequence_length
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported text_encoder_type: {self.text_encoder_type}"
+                )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -220,6 +282,7 @@ class Flux2Pipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -229,6 +292,8 @@ class Flux2Pipeline(DiffusionPipeline):
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         prompt_embeds_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         max_sequence_length: int = 512,
@@ -236,6 +301,11 @@ class Flux2Pipeline(DiffusionPipeline):
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        do_guidance = (
+            guidance_scale is not None
+            and guidance_scale > 1.0
+            and not self.is_guidance_distilled
+        )
+
         self._guidance_scale = guidance_scale
         self._current_timestep = None
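do_guidance folds three conditions together: a scale was given, it is above 1.0, and the loaded checkpoint is not guidance-distilled (the flux2 base model embeds the guidance scale directly, so running true CFG on it would apply guidance twice). The gate in isolation:

def wants_cfg(guidance_scale, is_guidance_distilled: bool) -> bool:
    # Same predicate as do_guidance above.
    return (
        guidance_scale is not None
        and guidance_scale > 1.0
        and not is_guidance_distilled
    )

assert wants_cfg(4.0, is_guidance_distilled=False)        # klein: true CFG
assert not wants_cfg(4.0, is_guidance_distilled=True)     # flux2 base: guidance is embedded
assert not wants_cfg(1.0, is_guidance_distilled=False)    # a scale of 1 is a no-op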
@@ -263,6 +333,19 @@ class Flux2Pipeline(DiffusionPipeline):
         )
         txt, txt_ids = batched_prc_txt(prompt_embeds)

+        neg_txt, neg_txt_ids = None, None
+        if do_guidance:
+            negative_prompt_embeds, _ = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_embeds=negative_prompt_embeds,
+                prompt_embeds_mask=negative_prompt_embeds_mask,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+            neg_txt, neg_txt_ids = batched_prc_txt(negative_prompt_embeds)
+
         # 4. Prepare latent variables
         latents = self.prepare_latents(
@@ -329,6 +412,17 @@ class Flux2Pipeline(DiffusionPipeline):
                 guidance=guidance_vec,
             )

+            if do_guidance:
+                pred_uncond = self.transformer(
+                    x=img_input,
+                    x_ids=img_input_ids,
+                    timesteps=t_vec,
+                    ctx=neg_txt,
+                    ctx_ids=neg_txt_ids,
+                    guidance=guidance_vec,
+                )
+                pred = pred_uncond + guidance_scale * (pred - pred_uncond)
+
             if img_cond_seq is not None:
                 pred = pred[:, : packed_latents.shape[1]]
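The second transformer call evaluates the unconditional branch, then the standard classifier-free guidance update pushes the prediction away from it: pred = pred_uncond + s * (pred_cond - pred_uncond). On dummy tensors, s = 1 recovers the conditional prediction exactly:

import torch

guidance_scale = 4.0
pred_cond = torch.randn(1, 64, 128)
pred_uncond = torch.randn(1, 64, 128)

pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
assert torch.allclose(pred_uncond + 1.0 * (pred_cond - pred_uncond), pred_cond)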

View File

@@ -630,6 +630,68 @@ export const modelArchs: ModelArch[] = [
     disableSections: ['network.conv'],
     additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
   },
+  {
+    name: 'flux2_klein_4b',
+    label: 'FLUX.2-klein-base-4B',
+    group: 'image',
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-4B', defaultNameOrPath],
+      'config.process[0].model.quantize': [true, false],
+      'config.process[0].model.quantize_te': [true, false],
+      'config.process[0].model.low_vram': [true, false],
+      'config.process[0].train.unload_text_encoder': [false, false],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
+      'config.process[0].model.model_kwargs': [
+        {
+          match_target_res: false,
+        },
+        {},
+      ],
+    },
+    disableSections: ['network.conv'],
+    additionalSections: [
+      'datasets.multi_control_paths',
+      'sample.multi_ctrl_imgs',
+      'model.low_vram',
+      'model.layer_offloading',
+      'model.qie.match_target_res',
+    ],
+  },
+  {
+    name: 'flux2_klein_9b',
+    label: 'FLUX.2-klein-base-9B',
+    group: 'image',
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-9B', defaultNameOrPath],
+      'config.process[0].model.quantize': [true, false],
+      'config.process[0].model.quantize_te': [true, false],
+      'config.process[0].model.low_vram': [true, false],
+      'config.process[0].train.unload_text_encoder': [false, false],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
+      'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
+      'config.process[0].model.model_kwargs': [
+        {
+          match_target_res: false,
+        },
+        {},
+      ],
+    },
+    disableSections: ['network.conv'],
+    additionalSections: [
+      'datasets.multi_control_paths',
+      'sample.multi_ctrl_imgs',
+      'model.low_vram',
+      'model.layer_offloading',
+      'model.qie.match_target_res',
+    ],
+  },
 ].sort((a, b) => {
   // Sort by label, case-insensitive
   return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });

View File

@@ -1 +1 @@
VERSION = "0.7.19" VERSION = "0.7.20"