diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index 1dd3da9..a047516 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -210,6 +210,17 @@ class QwenImageModel(BaseModel): control_img = control_img.resize( (gen_config.width, gen_config.height), Image.BILINEAR ) + + # flush for low vram if we are doing that + flush_between_steps = self.model_config.low_vram + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) @@ -224,6 +235,7 @@ class QwenImageModel(BaseModel): true_cfg_scale=gen_config.guidance_scale, latents=gen_config.latents, generator=generator, + callback_on_step_end=callback_on_step_end, **extra ).images[0] return img