# Video-export helpers belonging to modules_forge/forge_util.py.
#
# The accompanying change to
# extensions-builtin/sd_forge_svd/scripts/forge_svd.py imports
# write_images_to_mp4, makes predict() return (outputs, video_filename), and
# routes the second value to a gr.Video(autoplay=True) component placed next
# to the existing gallery.

import os
import random
import string
import time


def generate_random_filename(extension=".txt"):
    """Return a name shaped like ``<YYYYmmdd-HHMMSS>-<5 random chars><extension>``.

    The random part is drawn from lowercase ASCII letters and digits, which
    keeps names generated within the same second from colliding in practice.

    Args:
        extension: Suffix appended verbatim (include the leading dot).

    Returns:
        The generated file name (no directory component).
    """
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    random_string = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
    return f"{timestamp}-{random_string}{extension}"


def write_images_to_mp4(frame_list: list, filename=None, fps=6):
    """Encode a sequence of image arrays into an H.264 MP4 file.

    The video is written to ``<default_output_dir>/svd/<filename>``; the
    folder is created when missing. If PyAV (``av``) cannot be imported it is
    installed on the fly through ``launch.run_pip`` and imported again.

    Args:
        frame_list: Non-empty sequence of frames. Each frame is handed to
            ``av.VideoFrame.from_ndarray`` with its default pixel format, and
            the first frame's ``shape[1]`` / ``shape[0]`` define the video
            width / height.
        filename: Output file name. When ``None`` a random ``.mp4`` name is
            generated with ``generate_random_filename``.
        fps: Frame rate passed to the encoder stream.

    Returns:
        The full path of the written video file.

    Raises:
        ValueError: If ``frame_list`` is ``None`` or empty. This is checked
            before anything is created on disk.
    """
    # Fail early: without a first frame the stream size is unknown, and we do
    # not want to leave an empty folder/file behind.
    if frame_list is None or len(frame_list) == 0:
        raise ValueError("write_images_to_mp4 needs at least one frame")

    from modules.paths_internal import default_output_dir

    video_folder = os.path.join(default_output_dir, 'svd')
    os.makedirs(video_folder, exist_ok=True)

    if filename is None:
        filename = generate_random_filename('.mp4')

    full_path = os.path.join(video_folder, filename)

    try:
        import av
    except ImportError:
        # PyAV is an optional dependency; install it lazily on first use.
        from launch import run_pip
        run_pip(
            "install imageio[pyav]",
            "imageio[pyav]",
        )
        import av

    # Encoder options must be strings; "crf" is the x264 quality setting.
    options = {
        "crf": str(23)
    }

    first_frame = frame_list[0]

    output = av.open(full_path, "w")
    try:
        stream = output.add_stream('libx264', fps, options=options)
        stream.width = first_frame.shape[1]
        stream.height = first_frame.shape[0]
        for img in frame_list:
            frame = av.VideoFrame.from_ndarray(img)
            output.mux(stream.encode(frame))
        # Passing None flushes the frames still buffered inside the encoder.
        output.mux(stream.encode(None))
    finally:
        # Always release the container, even when encoding fails midway.
        output.close()

    return full_path