mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with wan22 14b that would load both transformers temporarily, resulting in OOM on 24GB.
This commit is contained in:
@@ -526,6 +526,44 @@ class Wan2214bModel(Wan21):
|
||||
combined_dict = new_dict
|
||||
|
||||
return combined_dict
|
||||
|
||||
def generate_single_image(
    self,
    pipeline,
    gen_config: GenerateImageConfig,
    conditional_embeds: PromptEmbeds,
    unconditional_embeds: PromptEmbeds,
    generator: torch.Generator,
    extra: dict,
):
    """Run ``pipeline`` once and return either one image or a list of frames.

    Args:
        pipeline: Callable pipeline object; it must also expose
            ``set_progress_bar_config``.
        gen_config: Supplies ``height``, ``width``, ``num_inference_steps``,
            ``guidance_scale``, ``latents`` and ``num_frames``.
        conditional_embeds: Its ``text_embeds`` are passed as ``prompt_embeds``.
        unconditional_embeds: Its ``text_embeds`` are passed as
            ``negative_prompt_embeds``.
        generator: Forwarded to the pipeline unchanged.
        extra: Additional keyword arguments forwarded to the pipeline.

    Returns:
        The first batch item as-is (the frames) when
        ``gen_config.num_frames > 1``; otherwise only its first element.
    """
    # Re-enable the progress bar: this sampling pass is slow.
    pipeline.set_progress_bar_config(disable=False)

    # Move both sets of text embeddings to the model's device and dtype.
    positive_embeds = conditional_embeds.text_embeds.to(
        self.device_torch, dtype=self.torch_dtype
    )
    negative_embeds = unconditional_embeds.text_embeds.to(
        self.device_torch, dtype=self.torch_dtype
    )

    # todo, figure out how to do video
    pipeline_result = pipeline(
        prompt_embeds=positive_embeds,
        negative_prompt_embeds=negative_embeds,
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
        guidance_scale=gen_config.guidance_scale,
        latents=gen_config.latents,
        num_frames=gen_config.num_frames,
        generator=generator,
        return_dict=False,
        output_type="pil",
        **extra,
    )

    # With return_dict=False the first tuple element is the batched output;
    # the original author notes it as [1, frames, channels, height, width].
    batched_output = pipeline_result[0]
    frames = batched_output[0]  # list of pil images

    # Multi-frame requests get every frame; single-frame requests get one image.
    return frames if gen_config.num_frames > 1 else frames[0]
|
||||
|
||||
def get_model_to_train(self):
|
||||
# todo, loras wont load right unless they have the transformer_1 or transformer_2 in the key.
|
||||
|
||||
Reference in New Issue
Block a user