Added a flush during sampling to prevent VRAM spikes on low-VRAM Qwen

Jaret Burkett
2025-08-12 12:57:18 -06:00
parent 69ee99b6e1
commit 259d68d440


@@ -210,6 +210,17 @@ class QwenImageModel(BaseModel):
        control_img = control_img.resize(
            (gen_config.width, gen_config.height), Image.BILINEAR
        )
        # flush for low vram if we are doing that
        flush_between_steps = self.model_config.low_vram

        # Fix a bug in diffusers/torch
        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if flush_between_steps:
                flush()
            latents = callback_kwargs["latents"]
            return {"latents": latents}

        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)
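
For readers unfamiliar with the helper, flush() here is the toolkit's memory-release utility. A minimal sketch of what such a helper typically does (an assumed implementation, not copied from this repo):

import gc
import torch

def flush():
    # Drop Python-level references, then return cached CUDA blocks to the
    # driver so the next step's allocations don't raise peak VRAM.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Calling it on every step trades a little sampling speed for a flatter VRAM curve, which is only worth doing when low_vram is set, hence the flush_between_steps guard.
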
@@ -224,6 +235,7 @@ class QwenImageModel(BaseModel):
            true_cfg_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
            **extra
        ).images[0]
        return img
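
callback_on_step_end is the standard diffusers hook that fires after each denoising step: the pipeline passes itself, the step index, the timestep, and a dict of tensors, and writes any returned keys back into its locals. A minimal usage sketch assuming a stock diffusers pipeline (the model id and prompt are illustrative, not taken from this commit):

import torch
from diffusers import DiffusionPipeline

# Illustrative load; this commit's pipeline is constructed elsewhere in the repo.
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

def callback_on_step_end(pipe, i, t, callback_kwargs):
    # Runs once per denoising step; emptying the CUDA cache here keeps
    # peak VRAM flat at the cost of slightly slower sampling.
    torch.cuda.empty_cache()
    return {"latents": callback_kwargs["latents"]}

image = pipe(
    "a lakeside cabin at dawn",
    callback_on_step_end=callback_on_step_end,
).images[0]

Returning the latents unchanged keeps sampling intact; the hook is used purely as a per-step point to free memory.
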