# ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
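"""Flux2 Klein model variants for ai-toolkit.

Subclasses Flux2Model to swap in a Qwen3 text encoder for the Klein 4B and 9B
checkpoints, with optional quantization and layer offloading of the text
encoder. (Docstring added for orientation; wording is a summary of the code
below, not original to the module.)
"""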
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
from optimum.quanto import freeze

from toolkit.basic import flush
from toolkit.config_modules import ModelConfig
from toolkit.memory_management.manager import MemoryManager
from toolkit.util.quantize import quantize, get_qtype

from .flux2_model import Flux2Model
from .src.model import Klein9BParams, Klein4BParams


class Flux2KleinModel(Flux2Model):
    flux2_klein_te_path: str | None = None
    flux2_te_type: str = "qwen"  # "mistral" or "qwen"
    flux2_vae_path: str = "ai-toolkit/flux2_vae"
    flux2_is_guidance_distilled: bool = False

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs,
        )
        # use the new LoKr format on this new model by default
        self.use_old_lokr_format = False

    def load_te(self):
        if self.flux2_klein_te_path is None:
            raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel")
        dtype = self.torch_dtype

        self.print_and_status_update("Loading Qwen3")
        text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained(
            self.flux2_klein_te_path,
            torch_dtype=dtype,
        )
        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        # optionally quantize the text encoder weights to reduce memory use
        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen3")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
            freeze(text_encoder)
            flush()

        # optionally offload a percentage of the text encoder layers off the GPU
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )

        tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
        return text_encoder, tokenizer


class Flux2Klein4BModel(Flux2KleinModel):
    arch = "flux2_klein_4b"
    flux2_klein_te_path: str = "Qwen/Qwen3-4B"
    flux2_te_filename: str = "flux-2-klein-base-4b.safetensors"

    def get_flux2_params(self):
        return Klein4BParams()

    def get_base_model_version(self):
        return "flux2_klein_4b"


class Flux2Klein9BModel(Flux2KleinModel):
    arch = "flux2_klein_9b"
    flux2_klein_te_path: str = "Qwen/Qwen3-8B"
    flux2_te_filename: str = "flux-2-klein-base-9b.safetensors"

    def get_flux2_params(self):
        return Klein9BParams()

    def get_base_model_version(self):
        return "flux2_klein_9b"