mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-05 04:59:56 +00:00
Added a flushg during sampling to prevent spikes on low vram qwen
This commit is contained in:
@@ -210,6 +210,17 @@ class QwenImageModel(BaseModel):
|
||||
control_img = control_img.resize(
|
||||
(gen_config.width, gen_config.height), Image.BILINEAR
|
||||
)
|
||||
|
||||
# flush for low vram if we are doing that
|
||||
flush_between_steps = self.model_config.low_vram
|
||||
# Fix a bug in diffusers/torch
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if flush_between_steps:
|
||||
flush()
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
return {"latents": latents}
|
||||
|
||||
sc = self.get_bucket_divisibility()
|
||||
gen_config.width = int(gen_config.width // sc * sc)
|
||||
gen_config.height = int(gen_config.height // sc * sc)
|
||||
@@ -224,6 +235,7 @@ class QwenImageModel(BaseModel):
|
||||
true_cfg_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
generator=generator,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
**extra
|
||||
).images[0]
|
||||
return img
|
||||
|
||||
Reference in New Issue
Block a user