# Reconstructed from git patch c9b6df7ec181a7e5c9deafa5690e194c8116dfe9
# ("Create forge_z123.py", lllyasviel, Fri, 26 Jan 2024 08:02:02 -0800).
# Target path: extensions-builtin/sd_forge_z123/scripts/forge_z123.py
"""Gradio tab ("Z123") that runs a Stable Zero123 checkpoint on one input image.

The module scans ``<models_path>/z123`` for checkpoint files, builds a UI tab
with the conditioning/sampling controls, and registers that tab with the
WebUI through ``script_callbacks.on_ui_tabs``.
"""
import torch
import gradio as gr
import os
import pathlib

from modules import script_callbacks
from modules.paths import models_path
from modules.ui_common import ToolButton, refresh_symbol
from modules import shared

from modules_forge.forge_util import numpy_to_pytorch, pytorch_to_numpy
from ldm_patched.modules.sd import load_checkpoint_guess_config
from ldm_patched.contrib.external_stable3d import StableZero123_Conditioning
from ldm_patched.contrib.external import KSampler, VAEDecode


# Node objects are created once at import time and reused by every predict() call.
opStableZero123_Conditioning = StableZero123_Conditioning()
opKSampler = KSampler()
opVAEDecode = VAEDecode()

# Directory that is scanned for checkpoints; created on import if it is missing.
# NOTE(review): the "svd_" prefix looks inherited from another extension; the
# names are kept because other code may reference them.
svd_root = os.path.join(models_path, 'z123')
os.makedirs(svd_root, exist_ok=True)
svd_filenames = []


def update_svd_filenames():
    """Rescan ``svd_root`` and refresh the module-level ``svd_filenames`` list.

    Only the base name of each file reported by ``shared.walk_files`` is kept
    (extensions ``.pt``, ``.ckpt`` and ``.safetensors`` are requested).

    Returns:
        The freshly built list, which is also stored in ``svd_filenames``.
    """
    global svd_filenames
    svd_filenames = [
        pathlib.Path(x).name for x in
        shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
    ]
    return svd_filenames


@torch.inference_mode()
@torch.no_grad()
def predict(filename, width, height, batch_size, elevation, azimuth,
            sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, input_image):
    """Load the selected checkpoint and generate images from ``input_image``.

    Args:
        filename: Checkpoint file name, resolved relative to ``svd_root``.
        width, height, batch_size, elevation, azimuth: Forwarded unchanged to
            ``StableZero123_Conditioning.encode``.
        sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name,
        sampling_scheduler, sampling_denoise: Forwarded unchanged to
            ``KSampler.sample``.
        input_image: Image from the ``gr.Image`` component (configured with
            ``type='numpy'``), converted with ``numpy_to_pytorch``.

    Returns:
        The result of ``pytorch_to_numpy`` applied to the decoded pixels; it is
        wired to the output gallery by ``on_ui_tabs``.
    """
    filename = os.path.join(svd_root, filename)
    # The CLIP text encoder is not requested (output_clip=False); only the
    # model, the VAE and the CLIP vision encoder are used below.
    model, _, vae, clip_vision = \
        load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
    init_image = numpy_to_pytorch(input_image)
    positive, negative, latent_image = opStableZero123_Conditioning.encode(
        clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth)
    output_latent = opKSampler.sample(model, sampling_seed, sampling_steps, sampling_cfg,
                                      sampling_sampler_name, sampling_scheduler, positive,
                                      negative, latent_image, sampling_denoise)[0]
    output_pixels = opVAEDecode.decode(vae, output_latent)[0]
    outputs = pytorch_to_numpy(output_pixels)
    return outputs


def on_ui_tabs():
    """Build the "Z123" tab and return it in the form expected by the WebUI.

    Returns:
        A one-element list holding ``(blocks, "Z123", "z123")``.
    """
    with gr.Blocks() as svd_block:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input Image', source='upload', type='numpy', height=400)

                with gr.Row():
                    # NOTE(review): the label says "SVD" although this tab loads
                    # Zero123 checkpoints; left unchanged pending confirmation.
                    filename = gr.Dropdown(label="SVD Checkpoint Filename",
                                           choices=svd_filenames,
                                           value=svd_filenames[0] if len(svd_filenames) > 0 else None)
                    refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh")
                    # Call update_svd_filenames() so the dropdown receives the
                    # rescanned list; passing the bare function object would
                    # hand Gradio a callable instead of the choices.
                    refresh_button.click(
                        fn=lambda: gr.update(choices=update_svd_filenames()),
                        inputs=[], outputs=filename)

                width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=256)
                height = gr.Slider(label='Height', minimum=16, maximum=8192, step=8, value=256)
                batch_size = gr.Slider(label='Batch Size', minimum=1, maximum=4096, step=1, value=4)
                elevation = gr.Slider(label='Elevation', minimum=-180.0, maximum=180.0, step=0.001, value=10.0)
                azimuth = gr.Slider(label='Azimuth', minimum=-180.0, maximum=180.0, step=0.001, value=142.0)
                sampling_denoise = gr.Slider(label='Sampling Denoise', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
                sampling_steps = gr.Slider(label='Sampling Steps', minimum=1, maximum=10000, step=1, value=20)
                sampling_cfg = gr.Slider(label='CFG Scale', minimum=0.0, maximum=100.0, step=0.1, value=5.0)
                sampling_sampler_name = gr.Radio(label='Sampling Sampler Name',
                                                 choices=['euler', 'euler_ancestral', 'heun', 'heunpp2', 'dpm_2',
                                                          'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
                                                          'dpmpp_2s_ancestral', 'dpmpp_sde', 'dpmpp_sde_gpu',
                                                          'dpmpp_2m', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu',
                                                          'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ddim',
                                                          'uni_pc', 'uni_pc_bh2'], value='euler')
                sampling_scheduler = gr.Radio(label='Sampling Scheduler',
                                              choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple',
                                                       'ddim_uniform'], value='sgm_uniform')
                sampling_seed = gr.Number(label='Seed', value=12345, precision=0)
                generate_button = gr.Button(value="Generate")

                # Order must match the positional parameters of predict().
                ctrls = [filename, width, height, batch_size, elevation, azimuth, sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, input_image]

            with gr.Column():
                output_gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain',
                                            visible=True, height=1024, columns=4)

        generate_button.click(predict, inputs=ctrls, outputs=[output_gallery])
    return [(svd_block, "Z123", "z123")]


# Populate the checkpoint list once at import, then register the tab builder.
update_svd_filenames()
script_callbacks.on_ui_tabs(on_ui_tabs)