From 8c10ec65f0ab4f997c6c8db133b6c3385ee87b80 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 23:09:10 -0800 Subject: [PATCH] Update forge_svd.py --- .../sd_forge_svd/scripts/forge_svd.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/sd_forge_svd/scripts/forge_svd.py b/extensions-builtin/sd_forge_svd/scripts/forge_svd.py index 548cd0e7..6493ce54 100644 --- a/extensions-builtin/sd_forge_svd/scripts/forge_svd.py +++ b/extensions-builtin/sd_forge_svd/scripts/forge_svd.py @@ -6,18 +6,25 @@ from modules import scripts, script_callbacks from modules.paths import models_path from modules.ui_common import ToolButton, refresh_symbol from modules import shared -from modules_forge.gradio_compile import gradio_compile -from ldm_patched.contrib.external_video_model import ImageOnlyCheckpointLoader, VideoLinearCFGGuidance, SVD_img2vid_Conditioning +from ldm_patched.modules.sd import load_checkpoint_guess_config +from ldm_patched.contrib.external_video_model import VideoLinearCFGGuidance, SVD_img2vid_Conditioning from ldm_patched.contrib.external import KSampler, VAEDecode +# from modules_forge.gradio_compile import gradio_compile # ps = [] # ps += gradio_compile(SVD_img2vid_Conditioning.INPUT_TYPES(), prefix='') # ps += gradio_compile(KSampler.INPUT_TYPES(), prefix='sampling') # ps += gradio_compile(VideoLinearCFGGuidance.INPUT_TYPES(), prefix='guidance') # print(', '.join(ps)) + +opVideoLinearCFGGuidance = VideoLinearCFGGuidance() +opSVD_img2vid_Conditioning = SVD_img2vid_Conditioning() +opKSampler = KSampler() +opVAEDecode = VAEDecode() + svd_root = os.path.join(models_path, 'svd') os.makedirs(svd_root, exist_ok=True) svd_filenames = [] @@ -32,6 +39,9 @@ def update_svd_filenames(): def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level, sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, guidance_min_cfg): + filename = os.path.join(svd_root, filename) + model, _, vae, clip_vision = load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True) + a = 0 return