Update forge_svd.py

This commit is contained in:
lllyasviel
2024-01-25 23:09:10 -08:00
parent a3ec20b03f
commit 8c10ec65f0

View File

@@ -6,18 +6,25 @@ from modules import scripts, script_callbacks
from modules.paths import models_path
from modules.ui_common import ToolButton, refresh_symbol
from modules import shared
from modules_forge.gradio_compile import gradio_compile
from ldm_patched.contrib.external_video_model import ImageOnlyCheckpointLoader, VideoLinearCFGGuidance, SVD_img2vid_Conditioning
from ldm_patched.modules.sd import load_checkpoint_guess_config
from ldm_patched.contrib.external_video_model import VideoLinearCFGGuidance, SVD_img2vid_Conditioning
from ldm_patched.contrib.external import KSampler, VAEDecode
# from modules_forge.gradio_compile import gradio_compile
# ps = []
# ps += gradio_compile(SVD_img2vid_Conditioning.INPUT_TYPES(), prefix='')
# ps += gradio_compile(KSampler.INPUT_TYPES(), prefix='sampling')
# ps += gradio_compile(VideoLinearCFGGuidance.INPUT_TYPES(), prefix='guidance')
# print(', '.join(ps))
opVideoLinearCFGGuidance = VideoLinearCFGGuidance()
opSVD_img2vid_Conditioning = SVD_img2vid_Conditioning()
opKSampler = KSampler()
opVAEDecode = VAEDecode()
svd_root = os.path.join(models_path, 'svd')
os.makedirs(svd_root, exist_ok=True)
svd_filenames = []
@@ -32,6 +39,9 @@ def update_svd_filenames():
def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
sampling_denoise, guidance_min_cfg):
filename = os.path.join(svd_root, filename)
model, _, vae, clip_vision = load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
a = 0
return