Fix issue with wan22 14b that would load both transformers temporarily, resulting in OOM on 24GB cards.

Jaret Burkett
2025-08-28 13:06:31 -06:00
parent e3349414fd
commit 056711d4ed
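For context, the failure mode named in the message is both 14b transformers being resident in VRAM at the same time. Below is a minimal sketch of the avoidance pattern, assuming hypothetical names (swap_active_transformer and the module arguments are illustrative, not taken from this repo):

import gc
import torch

def swap_active_transformer(active, inactive, device="cuda"):
    # Move the resident transformer off the GPU *before* bringing the
    # other one on, so both 14b models never occupy VRAM at once.
    active.to("cpu")
    gc.collect()
    torch.cuda.empty_cache()
    inactive.to(device)
    return inactive, active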


@@ -526,6 +526,44 @@ class Wan2214bModel(Wan21):
        combined_dict = new_dict
        return combined_dict

    def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # reactivate the progress bar since sampling here is slow
        pipeline.set_progress_bar_config(disable=False)
        # todo, figure out how to do video
        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            **extra
        )[0]
        # output is nested [batch][frames]; with output_type="pil" each
        # batch item is a list of PIL images
        batch_item = output[0]  # frames for the first batch item
        if gen_config.num_frames > 1:
            return batch_item  # return all frames
        else:
            # get just the first image
            img = batch_item[0]
            return img
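A hedged usage sketch of the method above. Every name here (model, pipeline, the embeds, and the GenerateImageConfig constructor fields) is an assumption for illustration, not taken from this diff:

import torch

gen_config = GenerateImageConfig(
    height=480,
    width=832,
    num_inference_steps=30,
    guidance_scale=5.0,
    latents=None,
    num_frames=1,  # > 1 returns a list of PIL frames instead of one image
)
generator = torch.Generator(device="cuda").manual_seed(42)
img = model.generate_single_image(
    pipeline,
    gen_config,
    conditional_embeds,
    unconditional_embeds,
    generator,
    extra={},
)
img.save("sample.png")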
    def get_model_to_train(self):
        # todo: loras won't load right unless they have transformer_1 or transformer_2 in the key.
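The todo describes LoRA keys that lack a transformer_1/transformer_2 prefix. A hedged sketch of the kind of remap it implies (the function name and routing rule are hypothetical, not this repo's loader):

def remap_lora_keys(state_dict: dict, target: str = "transformer_1") -> dict:
    # Route bare "transformer." LoRA keys to transformer_1 or transformer_2
    # so they load onto the intended half of the 14b pair.
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(("transformer_1.", "transformer_2.")):
            remapped[key] = value  # already routed correctly
        elif key.startswith("transformer."):
            remapped[key.replace("transformer.", f"{target}.", 1)] = value
        else:
            remapped[key] = value  # non-transformer keys pass through
    return remapped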