Update forge_z123.py

This commit is contained in:
lllyasviel
2024-01-26 08:26:30 -08:00
parent 30023256e4
commit 9ef2c4aec5

View File

@@ -18,25 +18,25 @@ opStableZero123_Conditioning = StableZero123_Conditioning()
opKSampler = KSampler()
opVAEDecode = VAEDecode()
svd_root = os.path.join(models_path, 'z123')
os.makedirs(svd_root, exist_ok=True)
svd_filenames = []
model_root = os.path.join(models_path, 'z123')
os.makedirs(model_root, exist_ok=True)
model_filenames = []
def update_svd_filenames():
global svd_filenames
svd_filenames = [
def update_model_filenames():
global model_filenames
model_filenames = [
pathlib.Path(x).name for x in
shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
shared.walk_files(model_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
]
return svd_filenames
return model_filenames
@torch.inference_mode()
@torch.no_grad()
def predict(filename, width, height, batch_size, elevation, azimuth,
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, input_image):
filename = os.path.join(svd_root, filename)
filename = os.path.join(model_root, filename)
model, _, vae, clip_vision = \
load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
init_image = numpy_to_pytorch(input_image)
@@ -51,18 +51,18 @@ def predict(filename, width, height, batch_size, elevation, azimuth,
def on_ui_tabs():
with gr.Blocks() as svd_block:
with gr.Blocks() as model_block:
with gr.Row():
with gr.Column():
input_image = gr.Image(label='Input Image', source='upload', type='numpy', height=400)
with gr.Row():
filename = gr.Dropdown(label="SVD Checkpoint Filename",
choices=svd_filenames,
value=svd_filenames[0] if len(svd_filenames) > 0 else None)
filename = gr.Dropdown(label="Zero123 Checkpoint Filename",
choices=model_filenames,
value=model_filenames[0] if len(model_filenames) > 0 else None)
refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh")
refresh_button.click(
fn=lambda: gr.update(choices=update_svd_filenames),
fn=lambda: gr.update(choices=update_model_filenames),
inputs=[], outputs=filename)
width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=256)
@@ -93,8 +93,8 @@ def on_ui_tabs():
visible=True, height=1024, columns=4)
generate_button.click(predict, inputs=ctrls, outputs=[output_gallery])
return [(svd_block, "Z123", "z123")]
return [(model_block, "Z123", "z123")]
update_svd_filenames()
update_model_filenames()
script_callbacks.on_ui_tabs(on_ui_tabs)