Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 19:21:39 +00:00
Add support for FLUX.2 klein base models
@@ -55,6 +55,10 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
 class Flux2Model(BaseModel):
     arch = "flux2"
+    flux2_te_type: str = "mistral"  # "mistral" or "qwen"
+    flux2_vae_path: str = None
+    flux2_te_filename: str = FLUX2_TRANSFORMER_FILENAME
+    flux2_is_guidance_distilled: bool = True
 
     def __init__(
         self,
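The four class attributes added above are override points: a FLUX.2 variant can be described declaratively by subclassing Flux2Model instead of duplicating its loading code. As a rough illustration (the klein subclass itself is not part of this hunk, so every name and value below is an assumption):

class Flux2KleinModel(Flux2Model):
    # Hypothetical sketch, not from this commit.
    arch = "flux2_klein"                     # assumed arch id
    flux2_te_type = "qwen"                   # the non-Mistral TE path
    flux2_te_filename = "klein.safetensors"  # assumed checkpoint filename
    flux2_is_guidance_distilled = False      # turns on real CFG when sampling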
@@ -84,6 +88,42 @@ class Flux2Model(BaseModel):
     def get_bucket_divisibility(self):
         return 16
 
+    def get_flux2_params(self):
+        return Flux2Params()
+
+    def load_te(self):
+        dtype = self.torch_dtype
+        self.print_and_status_update("Loading Mistral")
+
+        text_encoder: Mistral3ForConditionalGeneration = (
+            Mistral3ForConditionalGeneration.from_pretrained(
+                MISTRAL_PATH,
+                torch_dtype=dtype,
+            )
+        )
+        text_encoder.to(self.device_torch, dtype=dtype)
+
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Mistral")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
+            freeze(text_encoder)
+            flush()
+
+        if (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent > 0
+        ):
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
+            )
+
+        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        return text_encoder, tokenizer
+
     def load_model(self):
         dtype = self.torch_dtype
         self.print_and_status_update("Loading Flux2 model")
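Factoring the Mistral loading into load_te() means a variant with a different text encoder overrides one method and inherits the rest of load_model(). A minimal sketch, assuming a hypothetical QWEN_PATH and that the encoder loads through transformers' Auto classes (neither is named by this commit):

from transformers import AutoModel, AutoProcessor

QWEN_PATH = "some-org/some-qwen-checkpoint"  # placeholder id, not from this commit

class Flux2QwenTEModel(Flux2Model):
    flux2_te_type = "qwen"

    def load_te(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Qwen text encoder")
        # Quantization and layer offloading would mirror the Mistral branch above.
        text_encoder = AutoModel.from_pretrained(QWEN_PATH, torch_dtype=dtype)
        text_encoder.to(self.device_torch, dtype=dtype)
        tokenizer = AutoProcessor.from_pretrained(QWEN_PATH)
        return text_encoder, tokenizer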
@@ -93,19 +133,17 @@ class Flux2Model(BaseModel):
 
         self.print_and_status_update("Loading transformer")
         with torch.device("meta"):
-            transformer = Flux2(Flux2Params())
+            transformer = Flux2(self.get_flux2_params())
 
         # use local path if provided
-        if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
-            transformer_path = os.path.join(
-                transformer_path, FLUX2_TRANSFORMER_FILENAME
-            )
+        if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
+            transformer_path = os.path.join(transformer_path, self.flux2_te_filename)
 
         if not os.path.exists(transformer_path):
             # assume it is from the hub
             transformer_path = huggingface_hub.hf_hub_download(
                 repo_id=model_path,
-                filename=FLUX2_TRANSFORMER_FILENAME,
+                filename=self.flux2_te_filename,
                 token=HF_TOKEN,
             )
 
@@ -143,35 +181,7 @@ class Flux2Model(BaseModel):
         self.print_and_status_update("Moving transformer to CPU")
         transformer.to("cpu")
 
-        self.print_and_status_update("Loading Mistral")
-
-        text_encoder: Mistral3ForConditionalGeneration = (
-            Mistral3ForConditionalGeneration.from_pretrained(
-                MISTRAL_PATH,
-                torch_dtype=dtype,
-            )
-        )
-        text_encoder.to(self.device_torch, dtype=dtype)
-
-        flush()
-
-        if self.model_config.quantize_te:
-            self.print_and_status_update("Quantizing Mistral")
-            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
-            freeze(text_encoder)
-            flush()
-
-        if (
-            self.model_config.layer_offloading
-            and self.model_config.layer_offloading_text_encoder_percent > 0
-        ):
-            MemoryManager.attach(
-                text_encoder,
-                self.device_torch,
-                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
-            )
-
-        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
+        text_encoder, tokenizer = self.load_te()
 
         self.print_and_status_update("Loading VAE")
         vae_path = self.model_config.vae_path
@@ -179,10 +189,14 @@ class Flux2Model(BaseModel):
         if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
             vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
 
+        if vae_path is None:
+            vae_path = self.flux2_vae_path
+
         if vae_path is None or not os.path.exists(vae_path):
+            p = vae_path if vae_path is not None else model_path
             # assume it is from the hub
             vae_path = huggingface_hub.hf_hub_download(
-                repo_id=model_path,
+                repo_id=p,
                 filename=FLUX2_VAE_FILENAME,
                 token=HF_TOKEN,
             )
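With this fallback, flux2_vae_path doubles as a hub repo id: a value that is not a local path fails the os.path.exists check and is then handed to hf_hub_download as repo_id=p. A hedged sketch of how a subclass could use that (the repo id below is made up for illustration):

class Flux2KleinModel(Flux2Model):
    flux2_vae_path = "some-org/flux2-klein"  # hypothetical repo id

# With no vae_path in the user config and no local VAE file, load_model()
# would now resolve the VAE via:
#   huggingface_hub.hf_hub_download(
#       repo_id="some-org/flux2-klein",
#       filename=FLUX2_VAE_FILENAME,
#       token=HF_TOKEN,
#   )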
@@ -207,6 +221,8 @@ class Flux2Model(BaseModel):
             tokenizer=tokenizer,
             vae=vae,
             transformer=None,
+            text_encoder_type=self.flux2_te_type,
+            is_guidance_distilled=self.flux2_is_guidance_distilled,
         )
         # for quantization, it works best to do these after making the pipe
         pipe.transformer = transformer
@@ -241,6 +257,8 @@ class Flux2Model(BaseModel):
             tokenizer=self.tokenizer[0],
             vae=unwrap_model(self.vae),
             transformer=unwrap_model(self.transformer),
+            text_encoder_type=self.flux2_te_type,
+            is_guidance_distilled=self.flux2_is_guidance_distilled,
         )
 
         pipeline = pipeline.to(self.device_torch)
@@ -281,6 +299,9 @@ class Flux2Model(BaseModel):
                 control_img = control_img.convert("RGB")
                 control_img_list.append(control_img)
 
+        if not self.flux2_is_guidance_distilled:
+            extra["negative_prompt_embeds"] = unconditional_embeds.text_embeds
+
         img = pipeline(
             prompt_embeds=conditional_embeds.text_embeds,
             height=gen_config.height,
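Guidance-distilled FLUX.2 checkpoints bake guidance into the model, but a non-distilled one needs true classifier-free guidance, which is why the sampler now forwards the unconditional embeds to the pipeline. For reference, the standard CFG combination such a pipeline applies per denoising step (a generic sketch, not this pipeline's internals; variable names are illustrative):

noise_cond = transformer(latents, t, prompt_embeds)
noise_uncond = transformer(latents, t, negative_prompt_embeds)
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)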