Add support for FLUX.2 klein base models

Jaret Burkett
2026-01-17 17:46:25 -07:00
parent 0efed794b4
commit a6da9e37ac
8 changed files with 361 additions and 56 deletions


@@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
 class Flux2Model(BaseModel):
     arch = "flux2"
+    flux2_te_type: str = "mistral"  # "mistral" or "qwen"
+    flux2_vae_path: str = None
+    flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME
+    flux2_is_guidance_distilled: bool = True
 
     def __init__(
         self,
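The four new class attributes lift values that were previously hard-coded for FLUX.2 dev into an override surface, which is what makes a klein variant possible as a small subclass. A minimal sketch of how such a subclass could look, assuming a hypothetical class name, arch string, and attribute values (none of this is in the commit itself):

class Flux2KleinModel(Flux2Model):
    arch = "flux2_klein"                  # hypothetical arch string
    flux2_te_type = "qwen"                # assumed; the base comment allows "mistral" or "qwen"
    flux2_vae_path = None                 # could point at a klein-specific VAE file
    flux2_is_guidance_distilled = False   # assumed; enables real CFG during sampling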
@@ -84,6 +88,42 @@ class Flux2Model(BaseModel):
     def get_bucket_divisibility(self):
         return 16
 
+    def get_flux2_params(self):
+        return Flux2Params()
+
+    def load_te(self):
+        dtype = self.torch_dtype
+        self.print_and_status_update("Loading Mistral")
+        text_encoder: Mistral3ForConditionalGeneration = (
+            Mistral3ForConditionalGeneration.from_pretrained(
+                MISTRAL_PATH,
+                torch_dtype=dtype,
+            )
+        )
+        text_encoder.to(self.device_torch, dtype=dtype)
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Mistral")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
+            freeze(text_encoder)
+            flush()
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent > 0
+        ):
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
+            )
+
+        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        return text_encoder, tokenizer
+
     def load_model(self):
         dtype = self.torch_dtype
         self.print_and_status_update("Loading Flux2 model")
@@ -93,19 +133,17 @@ class Flux2Model(BaseModel):
         self.print_and_status_update("Loading transformer")
         with torch.device("meta"):
-            transformer = Flux2(Flux2Params())
+            transformer = Flux2(self.get_flux2_params())
 
         # use local path if provided
-        if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
-            transformer_path = os.path.join(
-                transformer_path, FLUX2_TRANSFORMER_FILENAME
-            )
+        if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
+            transformer_path = os.path.join(transformer_path, self.flux2_te_filename)
 
         if not os.path.exists(transformer_path):
             # assume it is from the hub
             transformer_path = huggingface_hub.hf_hub_download(
                 repo_id=model_path,
-                filename=FLUX2_TRANSFORMER_FILENAME,
+                filename=self.flux2_te_filename,
                 token=HF_TOKEN,
             )
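The lookup order for the transformer weights is unchanged, just parameterized on self.flux2_te_filename (which, despite the _te_ in its name, defaults to FLUX2_TRANSFORMER_FILENAME and selects the transformer weights file). The same order as a standalone sketch, using a hypothetical helper name:

import os
import huggingface_hub

def resolve_weights(base_path: str, repo_id: str, filename: str, token: str = None) -> str:
    candidate = os.path.join(base_path, filename)
    if os.path.exists(candidate):
        return candidate  # 1) filename inside a local directory
    if os.path.exists(base_path):
        return base_path  # 2) base_path already points directly at a weights file
    # 3) otherwise assume a hub repo id and download
    return huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename, token=token)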
@@ -143,35 +181,7 @@ class Flux2Model(BaseModel):
         self.print_and_status_update("Moving transformer to CPU")
         transformer.to("cpu")
 
-        self.print_and_status_update("Loading Mistral")
-        text_encoder: Mistral3ForConditionalGeneration = (
-            Mistral3ForConditionalGeneration.from_pretrained(
-                MISTRAL_PATH,
-                torch_dtype=dtype,
-            )
-        )
-        text_encoder.to(self.device_torch, dtype=dtype)
-        flush()
-
-        if self.model_config.quantize_te:
-            self.print_and_status_update("Quantizing Mistral")
-            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
-            freeze(text_encoder)
-            flush()
-
-        if (
-            self.model_config.layer_offloading
-            and self.model_config.layer_offloading_text_encoder_percent > 0
-        ):
-            MemoryManager.attach(
-                text_encoder,
-                self.device_torch,
-                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
-            )
-
-        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        text_encoder, tokenizer = self.load_te()
 
         self.print_and_status_update("Loading VAE")
         vae_path = self.model_config.vae_path
@@ -179,10 +189,14 @@ class Flux2Model(BaseModel):
         if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
             vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
 
-        if vae_path is None:
+        if vae_path is None:
+            vae_path = self.flux2_vae_path
+
+        if vae_path is None or not os.path.exists(vae_path):
+            p = vae_path if vae_path is not None else model_path
             # assume it is from the hub
             vae_path = huggingface_hub.hf_hub_download(
-                repo_id=model_path,
+                repo_id=p,
                 filename=FLUX2_VAE_FILENAME,
                 token=HF_TOKEN,
             )
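VAE resolution gains one rung: the class-level flux2_vae_path default now sits between the user-configured path and the hub fallback, and the hub download can target either the configured repo or model_path. The chain, summarized as a hypothetical helper with the same semantics:

import os
import huggingface_hub

def resolve_vae(config_path, model_path, class_default, filename, token=None):
    vae_path = config_path
    if os.path.exists(os.path.join(model_path, filename)):
        vae_path = os.path.join(model_path, filename)
    if vae_path is None:
        vae_path = class_default  # e.g. a variant-specific VAE file or repo id
    if vae_path is None or not os.path.exists(vae_path):
        repo = vae_path if vae_path is not None else model_path
        vae_path = huggingface_hub.hf_hub_download(repo_id=repo, filename=filename, token=token)
    return vae_path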
@@ -207,6 +221,8 @@ class Flux2Model(BaseModel):
tokenizer=tokenizer,
vae=vae,
transformer=None,
text_encoder_type=self.flux2_te_type,
is_guidance_distilled=self.flux2_is_guidance_distilled,
)
# for quantization, it works best to do these after making the pipe
pipe.transformer = transformer
@@ -241,6 +257,8 @@ class Flux2Model(BaseModel):
tokenizer=self.tokenizer[0],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer),
text_encoder_type=self.flux2_te_type,
is_guidance_distilled=self.flux2_is_guidance_distilled,
)
pipeline = pipeline.to(self.device_torch)
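Both pipeline constructions now forward text_encoder_type and is_guidance_distilled so the pipeline can branch on the variant. The distinction matters for sampling: a guidance-distilled model (FLUX.2 dev) bakes guidance into a single conditional pass, while a non-distilled variant needs true classifier-free guidance, which is why the final hunk below only forwards the unconditional embeds as negative_prompt_embeds when the model is not distilled. The standard per-step CFG mix such a pipeline would compute, as a sketch (not code from this commit):

import torch

def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # standard classifier-free guidance; distilled models skip the
    # unconditional pass entirely, which is what is_guidance_distilled gates
    return uncond + scale * (cond - uncond)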
@@ -281,6 +299,9 @@ class Flux2Model(BaseModel):
control_img = control_img.convert("RGB")
control_img_list.append(control_img)
if not self.flux2_is_guidance_distilled:
extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
height=gen_config.height,