# Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
# 93 lines · 2.8 KiB · Python
from typing import Optional

from optimum.quanto import freeze
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer

from toolkit.basic import flush
from toolkit.config_modules import ModelConfig
from toolkit.memory_management.manager import MemoryManager
from toolkit.util.quantize import quantize, get_qtype

from .flux2_model import Flux2Model
from .src.model import Klein9BParams, Klein4BParams
|
|
|
|
|
|
class Flux2KleinModel(Flux2Model):
    """Base wrapper for FLUX.2 Klein checkpoints.

    Klein variants use a Qwen3 causal LM as the text encoder (rather than
    Mistral — see ``flux2_te_type``). Concrete subclasses pin
    ``flux2_klein_te_path`` to the Qwen3 checkpoint and set the transformer
    weights filename.
    """

    # Hugging Face repo id (or local path) of the Qwen3 text encoder.
    # Must be overridden by subclasses; load_te() raises if left unset.
    # Fixed annotation: was `str = None`, which contradicts the default.
    flux2_klein_te_path: Optional[str] = None
    flux2_te_type: str = "qwen"  # "mistral" or "qwen"
    flux2_vae_path: str = "ai-toolkit/flux2_vae"
    flux2_is_guidance_distilled: bool = False

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs,
        )
        # use the new format on this new model by default
        self.use_old_lokr_format = False

    def load_te(self):
        """Load the Qwen3 text encoder and its tokenizer.

        Applies optional weight quantization (``model_config.quantize_te``)
        and optional layer offloading
        (``model_config.layer_offloading_text_encoder_percent``).

        Returns:
            tuple: ``(text_encoder, tokenizer)`` with the encoder moved to
            ``self.device_torch`` in ``self.torch_dtype``.

        Raises:
            ValueError: if ``flux2_klein_te_path`` was not set by a subclass.
        """
        if self.flux2_klein_te_path is None:
            raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel")
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Qwen3")

        text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained(
            self.flux2_klein_te_path,
            torch_dtype=dtype,
        )
        text_encoder.to(self.device_torch, dtype=dtype)

        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen3")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
            # freeze() finalizes quantized weights so they are not re-quantized.
            freeze(text_encoder)
            flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )

        # NOTE: Qwen3 checkpoints ship a Qwen2-compatible tokenizer, so the
        # Qwen2Tokenizer class is loaded from the same path.
        tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
        return text_encoder, tokenizer
|
|
|
|
|
|
class Flux2Klein4BModel(Flux2KleinModel):
    """FLUX.2 Klein, 4B variant — paired with the Qwen/Qwen3-4B text encoder."""

    arch = "flux2_klein_4b"
    flux2_klein_te_path: str = "Qwen/Qwen3-4B"
    flux2_te_filename: str = "flux-2-klein-base-4b.safetensors"

    def get_flux2_params(self):
        """Build the transformer hyperparameter set for the 4B Klein model."""
        return Klein4BParams()

    def get_base_model_version(self):
        """Report the version tag for this base model (matches ``arch``)."""
        return self.arch
|
|
|
|
|
|
class Flux2Klein9BModel(Flux2KleinModel):
    """FLUX.2 Klein, 9B variant — paired with the Qwen/Qwen3-8B text encoder."""

    arch = "flux2_klein_9b"
    flux2_klein_te_path: str = "Qwen/Qwen3-8B"
    flux2_te_filename: str = "flux-2-klein-base-9b.safetensors"

    def get_flux2_params(self):
        """Build the transformer hyperparameter set for the 9B Klein model."""
        return Klein9BParams()

    def get_base_model_version(self):
        """Report the version tag for this base model (matches ``arch``)."""
        return self.arch
|