mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-20 23:03:58 +00:00
Update forge_z123.py
This commit is contained in:
@@ -18,25 +18,25 @@ opStableZero123_Conditioning = StableZero123_Conditioning()
|
||||
opKSampler = KSampler()
|
||||
opVAEDecode = VAEDecode()
|
||||
|
||||
svd_root = os.path.join(models_path, 'z123')
|
||||
os.makedirs(svd_root, exist_ok=True)
|
||||
svd_filenames = []
|
||||
model_root = os.path.join(models_path, 'z123')
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
model_filenames = []
|
||||
|
||||
|
||||
def update_svd_filenames():
|
||||
global svd_filenames
|
||||
svd_filenames = [
|
||||
def update_model_filenames():
|
||||
global model_filenames
|
||||
model_filenames = [
|
||||
pathlib.Path(x).name for x in
|
||||
shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
|
||||
shared.walk_files(model_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
|
||||
]
|
||||
return svd_filenames
|
||||
return model_filenames
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def predict(filename, width, height, batch_size, elevation, azimuth,
|
||||
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler, sampling_denoise, input_image):
|
||||
filename = os.path.join(svd_root, filename)
|
||||
filename = os.path.join(model_root, filename)
|
||||
model, _, vae, clip_vision = \
|
||||
load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
|
||||
init_image = numpy_to_pytorch(input_image)
|
||||
@@ -51,18 +51,18 @@ def predict(filename, width, height, batch_size, elevation, azimuth,
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Blocks() as svd_block:
|
||||
with gr.Blocks() as model_block:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_image = gr.Image(label='Input Image', source='upload', type='numpy', height=400)
|
||||
|
||||
with gr.Row():
|
||||
filename = gr.Dropdown(label="SVD Checkpoint Filename",
|
||||
choices=svd_filenames,
|
||||
value=svd_filenames[0] if len(svd_filenames) > 0 else None)
|
||||
filename = gr.Dropdown(label="Zero123 Checkpoint Filename",
|
||||
choices=model_filenames,
|
||||
value=model_filenames[0] if len(model_filenames) > 0 else None)
|
||||
refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh")
|
||||
refresh_button.click(
|
||||
fn=lambda: gr.update(choices=update_svd_filenames),
|
||||
fn=lambda: gr.update(choices=update_model_filenames),
|
||||
inputs=[], outputs=filename)
|
||||
|
||||
width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=256)
|
||||
@@ -93,8 +93,8 @@ def on_ui_tabs():
|
||||
visible=True, height=1024, columns=4)
|
||||
|
||||
generate_button.click(predict, inputs=ctrls, outputs=[output_gallery])
|
||||
return [(svd_block, "Z123", "z123")]
|
||||
return [(model_block, "Z123", "z123")]
|
||||
|
||||
|
||||
update_svd_filenames()
|
||||
update_model_filenames()
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
||||
Reference in New Issue
Block a user