Update forge_svd.py

This commit is contained in:
lllyasviel
2024-01-25 23:46:26 -08:00
parent 4f142ebf95
commit 8e71497b52

View File

@@ -7,6 +7,7 @@ from modules.paths import models_path
from modules.ui_common import ToolButton, refresh_symbol
from modules import shared
from modules_forge.forge_util import numpy_to_pytorch, pytorch_to_numpy
from ldm_patched.modules.sd import load_checkpoint_guess_config
from ldm_patched.contrib.external_video_model import VideoLinearCFGGuidance, SVD_img2vid_Conditioning
from ldm_patched.contrib.external import KSampler, VAEDecode
@@ -41,8 +42,11 @@ def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmen
sampling_denoise, guidance_min_cfg, input_image):
filename = os.path.join(svd_root, filename)
model_raw, _, vae, clip_vision = load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
model = opVideoLinearCFGGuidance.patch(model_raw, guidance_min_cfg)
a = 0
model = opVideoLinearCFGGuidance.patch(model_raw, guidance_min_cfg)[0]
init_image = numpy_to_pytorch(input_image)
positive, negative, latent_image = opSVD_img2vid_Conditioning.encode(clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level)
output_latent = opKSampler.sample(model, sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, positive, negative, latent_image, sampling_denoise)
output_pixels = opVAEDecode.decode(vae, output_latent)[0]
return