From 9ef2c4aec500a8641830109fa18a6dfe46b3ffaf Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Fri, 26 Jan 2024 08:26:30 -0800 Subject: [PATCH] Update forge_z123.py --- .../sd_forge_z123/scripts/forge_z123.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/extensions-builtin/sd_forge_z123/scripts/forge_z123.py b/extensions-builtin/sd_forge_z123/scripts/forge_z123.py index c0abe39c..e67ad325 100644 --- a/extensions-builtin/sd_forge_z123/scripts/forge_z123.py +++ b/extensions-builtin/sd_forge_z123/scripts/forge_z123.py @@ -18,25 +18,25 @@ opStableZero123_Conditioning = StableZero123_Conditioning() opKSampler = KSampler() opVAEDecode = VAEDecode() -svd_root = os.path.join(models_path, 'z123') -os.makedirs(svd_root, exist_ok=True) -svd_filenames = [] +model_root = os.path.join(models_path, 'z123') +os.makedirs(model_root, exist_ok=True) +model_filenames = [] -def update_svd_filenames(): - global svd_filenames - svd_filenames = [ +def update_model_filenames(): + global model_filenames + model_filenames = [ pathlib.Path(x).name for x in - shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"]) + shared.walk_files(model_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"]) ] - return svd_filenames + return model_filenames @torch.inference_mode() @torch.no_grad() def predict(filename, width, height, batch_size, elevation, azimuth, sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, input_image): - filename = os.path.join(svd_root, filename) + filename = os.path.join(model_root, filename) model, _, vae, clip_vision = \ load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True) init_image = numpy_to_pytorch(input_image) @@ -51,18 +51,18 @@ def predict(filename, width, height, batch_size, elevation, azimuth, def on_ui_tabs(): - with gr.Blocks() as svd_block: + with gr.Blocks() as model_block: with gr.Row(): with gr.Column(): input_image = gr.Image(label='Input Image', source='upload', type='numpy', height=400) with gr.Row(): - filename = gr.Dropdown(label="SVD Checkpoint Filename", - choices=svd_filenames, - value=svd_filenames[0] if len(svd_filenames) > 0 else None) + filename = gr.Dropdown(label="Zero123 Checkpoint Filename", + choices=model_filenames, + value=model_filenames[0] if len(model_filenames) > 0 else None) refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh") refresh_button.click( - fn=lambda: gr.update(choices=update_svd_filenames), + fn=lambda: gr.update(choices=update_model_filenames), inputs=[], outputs=filename) width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=256) @@ -93,8 +93,8 @@ def on_ui_tabs(): visible=True, height=1024, columns=4) generate_button.click(predict, inputs=ctrls, outputs=[output_gallery]) - return [(svd_block, "Z123", "z123")] + return [(model_block, "Z123", "z123")] -update_svd_filenames() +update_model_filenames() script_callbacks.on_ui_tabs(on_ui_tabs)