Update forge_svd.py

This commit is contained in:
lllyasviel
2024-01-26 00:25:17 -08:00
parent 9ed196d706
commit bd269189b8

View File

@@ -39,6 +39,8 @@ def update_svd_filenames():
return svd_filenames
@torch.inference_mode()
@torch.no_grad()
def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
sampling_denoise, guidance_min_cfg, input_image):
@@ -47,10 +49,6 @@ def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmen
model = opVideoLinearCFGGuidance.patch(model_raw, guidance_min_cfg)[0]
init_image = numpy_to_pytorch(input_image)
positive, negative, latent_image = opSVD_img2vid_Conditioning.encode(clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level)
del model_raw, _, vae, clip_vision, model, init_image, positive, negative, latent_image
model_management.unload_all_models()
model_management.soft_empty_cache()
torch.cuda.empty_cache()
output_latent = opKSampler.sample(model, sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, positive, negative, latent_image, sampling_denoise)
output_pixels = opVAEDecode.decode(vae, output_latent)[0]
return