mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Add support for FLUX.2 klein base models
This commit is contained in:
@@ -5,7 +5,7 @@ from .omnigen2 import OmniGen2Model
|
||||
from .flux_kontext import FluxKontextModel
|
||||
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
|
||||
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
|
||||
from .flux2 import Flux2Model
|
||||
from .flux2 import Flux2Model, Flux2Klein4BModel, Flux2Klein9BModel
|
||||
from .z_image import ZImageModel
|
||||
from .ltx2 import LTX2Model
|
||||
|
||||
@@ -27,4 +27,6 @@ AI_TOOLKIT_MODELS = [
|
||||
Flux2Model,
|
||||
ZImageModel,
|
||||
LTX2Model,
|
||||
Flux2Klein4BModel,
|
||||
Flux2Klein9BModel,
|
||||
]
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .flux2_model import Flux2Model
|
||||
from .flux2_model import Flux2Model
|
||||
from .flux2_klein_model import Flux2Klein4BModel, Flux2Klein9BModel
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
from .flux2_model import Flux2Model
|
||||
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
|
||||
from optimum.quanto import freeze
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.memory_management.manager import MemoryManager
|
||||
from toolkit.basic import flush
|
||||
from .src.model import Klein9BParams, Klein4BParams
|
||||
|
||||
|
||||
class Flux2KleinModel(Flux2Model):
|
||||
flux2_klein_te_path: str = None
|
||||
flux2_te_type: str = "qwen" # "mistral" or "qwen"
|
||||
flux2_vae_path: str = "ai-toolkit/flux2_vae"
|
||||
flux2_is_guidance_distilled: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype="bf16",
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
device,
|
||||
model_config,
|
||||
dtype,
|
||||
custom_pipeline,
|
||||
noise_scheduler,
|
||||
**kwargs,
|
||||
)
|
||||
# use the new format on this new model by default
|
||||
self.use_old_lokr_format = False
|
||||
|
||||
def load_te(self):
|
||||
if self.flux2_klein_te_path is None:
|
||||
raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel")
|
||||
dtype = self.torch_dtype
|
||||
self.print_and_status_update("Loading Qwen3")
|
||||
|
||||
text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained(
|
||||
self.flux2_klein_te_path,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing Qwen3")
|
||||
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
if (
|
||||
self.model_config.layer_offloading
|
||||
and self.model_config.layer_offloading_text_encoder_percent > 0
|
||||
):
|
||||
MemoryManager.attach(
|
||||
text_encoder,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
|
||||
)
|
||||
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
|
||||
return text_encoder, tokenizer
|
||||
|
||||
|
||||
class Flux2Klein4BModel(Flux2KleinModel):
|
||||
arch = "flux2_klein_4b"
|
||||
flux2_klein_te_path: str = "Qwen/Qwen3-4B"
|
||||
flux2_te_filename: str = "flux-2-klein-base-4b.safetensors"
|
||||
|
||||
def get_flux2_params(self):
|
||||
return Klein4BParams()
|
||||
|
||||
def get_base_model_version(self):
|
||||
return "flux2_klein_4b"
|
||||
|
||||
|
||||
class Flux2Klein9BModel(Flux2KleinModel):
|
||||
arch = "flux2_klein_9b"
|
||||
flux2_klein_te_path: str = "Qwen/Qwen3-8B"
|
||||
flux2_te_filename: str = "flux-2-klein-base-9b.safetensors"
|
||||
|
||||
def get_flux2_params(self):
|
||||
return Klein9BParams()
|
||||
|
||||
def get_base_model_version(self):
|
||||
return "flux2_klein_9b"
|
||||
@@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
|
||||
|
||||
class Flux2Model(BaseModel):
|
||||
arch = "flux2"
|
||||
flux2_te_type: str = "mistral" # "mistral" or "qwen"
|
||||
flux2_vae_path: str = None
|
||||
flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME
|
||||
flux2_is_guidance_distilled: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,6 +88,42 @@ class Flux2Model(BaseModel):
|
||||
def get_bucket_divisibility(self):
|
||||
return 16
|
||||
|
||||
def get_flux2_params(self):
|
||||
return Flux2Params()
|
||||
|
||||
def load_te(self):
|
||||
dtype = self.torch_dtype
|
||||
self.print_and_status_update("Loading Mistral")
|
||||
|
||||
text_encoder: Mistral3ForConditionalGeneration = (
|
||||
Mistral3ForConditionalGeneration.from_pretrained(
|
||||
MISTRAL_PATH,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing Mistral")
|
||||
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
if (
|
||||
self.model_config.layer_offloading
|
||||
and self.model_config.layer_offloading_text_encoder_percent > 0
|
||||
):
|
||||
MemoryManager.attach(
|
||||
text_encoder,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
|
||||
)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
|
||||
return text_encoder, tokenizer
|
||||
|
||||
def load_model(self):
|
||||
dtype = self.torch_dtype
|
||||
self.print_and_status_update("Loading Flux2 model")
|
||||
@@ -93,19 +133,17 @@ class Flux2Model(BaseModel):
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
with torch.device("meta"):
|
||||
transformer = Flux2(Flux2Params())
|
||||
transformer = Flux2(self.get_flux2_params())
|
||||
|
||||
# use local path if provided
|
||||
if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
|
||||
transformer_path = os.path.join(
|
||||
transformer_path, FLUX2_TRANSFORMER_FILENAME
|
||||
)
|
||||
if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
|
||||
transformer_path = os.path.join(transformer_path, self.flux2_te_filename)
|
||||
|
||||
if not os.path.exists(transformer_path):
|
||||
# assume it is from the hub
|
||||
transformer_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=model_path,
|
||||
filename=FLUX2_TRANSFORMER_FILENAME,
|
||||
filename=self.flux2_te_filename,
|
||||
token=HF_TOKEN,
|
||||
)
|
||||
|
||||
@@ -143,35 +181,7 @@ class Flux2Model(BaseModel):
|
||||
self.print_and_status_update("Moving transformer to CPU")
|
||||
transformer.to("cpu")
|
||||
|
||||
self.print_and_status_update("Loading Mistral")
|
||||
|
||||
text_encoder: Mistral3ForConditionalGeneration = (
|
||||
Mistral3ForConditionalGeneration.from_pretrained(
|
||||
MISTRAL_PATH,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing Mistral")
|
||||
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
if (
|
||||
self.model_config.layer_offloading
|
||||
and self.model_config.layer_offloading_text_encoder_percent > 0
|
||||
):
|
||||
MemoryManager.attach(
|
||||
text_encoder,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
|
||||
)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
|
||||
text_encoder, tokenizer = self.load_te()
|
||||
|
||||
self.print_and_status_update("Loading VAE")
|
||||
vae_path = self.model_config.vae_path
|
||||
@@ -179,10 +189,14 @@ class Flux2Model(BaseModel):
|
||||
if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
|
||||
vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
|
||||
|
||||
if vae_path is None:
|
||||
vae_path = self.flux2_vae_path
|
||||
|
||||
if vae_path is None or not os.path.exists(vae_path):
|
||||
p = vae_path if vae_path is not None else model_path
|
||||
# assume it is from the hub
|
||||
vae_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=model_path,
|
||||
repo_id=p,
|
||||
filename=FLUX2_VAE_FILENAME,
|
||||
token=HF_TOKEN,
|
||||
)
|
||||
@@ -207,6 +221,8 @@ class Flux2Model(BaseModel):
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
transformer=None,
|
||||
text_encoder_type=self.flux2_te_type,
|
||||
is_guidance_distilled=self.flux2_is_guidance_distilled,
|
||||
)
|
||||
# for quantization, it works best to do these after making the pipe
|
||||
pipe.transformer = transformer
|
||||
@@ -241,6 +257,8 @@ class Flux2Model(BaseModel):
|
||||
tokenizer=self.tokenizer[0],
|
||||
vae=unwrap_model(self.vae),
|
||||
transformer=unwrap_model(self.transformer),
|
||||
text_encoder_type=self.flux2_te_type,
|
||||
is_guidance_distilled=self.flux2_is_guidance_distilled,
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(self.device_torch)
|
||||
@@ -281,6 +299,9 @@ class Flux2Model(BaseModel):
|
||||
control_img = control_img.convert("RGB")
|
||||
control_img_list.append(control_img)
|
||||
|
||||
if not self.flux2_is_guidance_distilled:
|
||||
extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds
|
||||
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
height=gen_config.height,
|
||||
|
||||
@@ -17,6 +17,35 @@ class Flux2Params:
|
||||
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
||||
theta: int = 2000
|
||||
mlp_ratio: float = 3.0
|
||||
use_guidance_embed: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Klein9BParams:
|
||||
in_channels: int = 128
|
||||
context_in_dim: int = 12288
|
||||
hidden_size: int = 4096
|
||||
num_heads: int = 32
|
||||
depth: int = 8
|
||||
depth_single_blocks: int = 24
|
||||
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
||||
theta: int = 2000
|
||||
mlp_ratio: float = 3.0
|
||||
use_guidance_embed: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Klein4BParams:
|
||||
in_channels: int = 128
|
||||
context_in_dim: int = 7680
|
||||
hidden_size: int = 3072
|
||||
num_heads: int = 24
|
||||
depth: int = 5
|
||||
depth_single_blocks: int = 20
|
||||
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
||||
theta: int = 2000
|
||||
mlp_ratio: float = 3.0
|
||||
use_guidance_embed: bool = False
|
||||
|
||||
|
||||
class FakeConfig:
|
||||
@@ -50,11 +79,14 @@ class Flux2(nn.Module):
|
||||
self.time_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
self.guidance_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.use_guidance_embed = params.use_guidance_embed
|
||||
if self.use_guidance_embed:
|
||||
self.guidance_in = MLPEmbedder(
|
||||
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@@ -116,14 +148,15 @@ class Flux2(nn.Module):
|
||||
timesteps: Tensor,
|
||||
ctx: Tensor,
|
||||
ctx_ids: Tensor,
|
||||
guidance: Tensor,
|
||||
guidance: Tensor | None,
|
||||
):
|
||||
num_txt_tokens = ctx.shape[1]
|
||||
|
||||
timestep_emb = timestep_embedding(timesteps, 256)
|
||||
vec = self.time_in(timestep_emb)
|
||||
guidance_emb = timestep_embedding(guidance, 256)
|
||||
vec = vec + self.guidance_in(guidance_emb)
|
||||
if self.use_guidance_embed:
|
||||
guidance_emb = timestep_embedding(guidance, 256)
|
||||
vec = vec + self.guidance_in(guidance_emb)
|
||||
|
||||
double_block_mod_img = self.double_stream_modulation_img(vec)
|
||||
double_block_mod_txt = self.double_stream_modulation_txt(vec)
|
||||
|
||||
@@ -4,8 +4,6 @@ import numpy as np
|
||||
import torch
|
||||
import PIL.Image
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
@@ -13,14 +11,10 @@ from diffusers.utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
from .autoencoder import AutoEncoder
|
||||
from .model import Flux2
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
from .sampling import (
|
||||
@@ -41,7 +35,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
|
||||
attribution and actions without speculation."""
|
||||
OUTPUT_LAYERS = [10, 20, 30]
|
||||
OUTPUT_LAYERS_MISTRAL = [10, 20, 30]
|
||||
OUTPUT_LAYERS_QWEN3 = [9, 18, 27]
|
||||
MAX_LENGTH = 512
|
||||
|
||||
|
||||
@@ -56,6 +51,8 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
text_encoder: Mistral3ForConditionalGeneration,
|
||||
tokenizer: AutoProcessor,
|
||||
transformer: Flux2,
|
||||
text_encoder_type: str = "mistral", # "mistral" or "qwen"
|
||||
is_guidance_distilled: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -70,6 +67,8 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
self.num_channels_latents = 128
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = 64
|
||||
self.text_encoder_type = text_encoder_type
|
||||
self.is_guidance_distilled = is_guidance_distilled
|
||||
|
||||
def format_input(
|
||||
self,
|
||||
@@ -138,12 +137,66 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
|
||||
out = torch.stack(
|
||||
[output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1
|
||||
)
|
||||
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
|
||||
|
||||
# they don't return attention mask, so we create it here
|
||||
return prompt_embeds, None
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for p in prompt:
|
||||
messages = [{"role": "user", "content": p}]
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(model_inputs["input_ids"])
|
||||
all_attention_masks.append(model_inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
output = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1)
|
||||
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
|
||||
|
||||
# they dont use attention mask
|
||||
return prompt_embeds, None
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -159,9 +212,18 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
|
||||
prompt, device, max_sequence_length=max_sequence_length
|
||||
)
|
||||
if self.text_encoder_type == "mistral":
|
||||
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
|
||||
prompt, device, max_sequence_length=max_sequence_length
|
||||
)
|
||||
elif self.text_encoder_type == "qwen":
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
|
||||
prompt, device, max_sequence_length=max_sequence_length
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported text_encoder_type: {self.text_encoder_type}"
|
||||
)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -220,6 +282,7 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
@@ -229,6 +292,8 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
max_sequence_length: int = 512,
|
||||
@@ -236,6 +301,11 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
):
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
do_guidance = (
|
||||
guidance_scale is not None
|
||||
and guidance_scale > 1.0
|
||||
and not self.is_guidance_distilled
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
@@ -263,6 +333,19 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
txt, txt_ids = batched_prc_txt(prompt_embeds)
|
||||
neg_txt, neg_txt_ids = None, None
|
||||
|
||||
if do_guidance:
|
||||
negative_prompt_embeds, _ = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
neg_txt, neg_txt_ids = batched_prc_txt(negative_prompt_embeds)
|
||||
|
||||
# 4. Prepare latent variables\
|
||||
latents = self.prepare_latents(
|
||||
@@ -329,6 +412,17 @@ class Flux2Pipeline(DiffusionPipeline):
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
if do_guidance:
|
||||
pred_uncond = self.transformer(
|
||||
x=img_input,
|
||||
x_ids=img_input_ids,
|
||||
timesteps=t_vec,
|
||||
ctx=neg_txt,
|
||||
ctx_ids=neg_txt_ids,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
pred = pred_uncond + guidance_scale * (pred - pred_uncond)
|
||||
|
||||
if img_cond_seq is not None:
|
||||
pred = pred[:, : packed_latents.shape[1]]
|
||||
|
||||
|
||||
@@ -630,6 +630,68 @@ export const modelArchs: ModelArch[] = [
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v'],
|
||||
},
|
||||
{
|
||||
name: 'flux2_klein_4b',
|
||||
label: 'FLUX.2-klein-base-4B',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-4B', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].train.unload_text_encoder': [false, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
'config.process[0].model.model_kwargs': [
|
||||
{
|
||||
match_target_res: false,
|
||||
},
|
||||
{},
|
||||
],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: [
|
||||
'datasets.multi_control_paths',
|
||||
'sample.multi_ctrl_imgs',
|
||||
'model.low_vram',
|
||||
'model.layer_offloading',
|
||||
'model.qie.match_target_res',
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'flux2_klein_9b',
|
||||
label: 'FLUX.2-klein-base-9B',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-9B', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].train.unload_text_encoder': [false, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
'config.process[0].model.model_kwargs': [
|
||||
{
|
||||
match_target_res: false,
|
||||
},
|
||||
{},
|
||||
],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: [
|
||||
'datasets.multi_control_paths',
|
||||
'sample.multi_ctrl_imgs',
|
||||
'model.low_vram',
|
||||
'model.layer_offloading',
|
||||
'model.qie.match_target_res',
|
||||
],
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.7.19"
|
||||
VERSION = "0.7.20"
|
||||
|
||||
Reference in New Issue
Block a user