mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Adjustments to loading of Flux. Added feedback to EMA.
This commit is contained in:
@@ -118,6 +118,7 @@ class StableDiffusion:
|
||||
dtype='fp16',
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
quantize_device=None,
|
||||
):
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.device = device
|
||||
@@ -171,6 +172,8 @@ class StableDiffusion:
|
||||
if self.is_flux or self.is_v3 or self.is_auraflow:
|
||||
self.is_flow_matching = True
|
||||
|
||||
self.quantize_device = quantize_device if quantize_device is not None else self.device
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
return
|
||||
@@ -454,10 +457,6 @@ class StableDiffusion:
|
||||
elif self.model_config.is_flux:
|
||||
print("Loading Flux model")
|
||||
base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
print("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
@@ -472,19 +471,19 @@ class StableDiffusion:
|
||||
# low_cpu_mem_usage=False,
|
||||
# device_map=None
|
||||
)
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
# need the pipe to do this unfortunately for now
|
||||
# we have to fuse in the weights before quantizing
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
scheduler=None,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer_2=None,
|
||||
vae=vae,
|
||||
vae=None,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||
@@ -496,6 +495,15 @@ class StableDiffusion:
|
||||
print("Quantizing transformer")
|
||||
quantize(transformer, weights=qfloat8)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
print("Loading t5")
|
||||
|
||||
Reference in New Issue
Block a user