diff --git a/launch.py b/launch.py index f83820d2..10aa5463 100644 --- a/launch.py +++ b/launch.py @@ -41,6 +41,9 @@ def main(): if args.test_server: configure_for_tests() + if args.forge_ref_a1111_home: + launch_utils.configure_forge_reference_checkout(args.forge_ref_a1111_home) + start() diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ef2ecd7b..0aa8ef78 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -2,6 +2,7 @@ import argparse import json import os from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401 +from pathlib import Path from ldm_patched.modules import args_parser parser = args_parser.parser @@ -122,3 +123,23 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) + +# Arguments added by forge. +parser.add_argument( + '--forge-ref-a1111-home', + type=Path, + help="Look for models in an existing A1111 checkout's path", + default=None +) +parser.add_argument( + "--controlnet-dir", + type=Path, + help="Path to directory with ControlNet models", + default=None, +) +parser.add_argument( + "--controlnet-preprocessor-models-dir", + type=Path, + help="Path to directory with annotator model directories", + default=None, +) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 409385a8..48be47f4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -10,6 +10,8 @@ import importlib.metadata import platform import json from functools import lru_cache +from typing import NamedTuple +from pathlib import Path from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir, extensions_builtin_dir @@ -503,6 +505,37 @@ def configure_for_tests(): os.environ['COMMANDLINE_ARGS'] = "" +def configure_forge_reference_checkout(a1111_home: Path): + """Set model paths based on an existing A1111 checkout.""" + class ModelRef(NamedTuple): + arg_name: str + relative_path: str + + refs = [ + ModelRef(arg_name="--ckpt-dir", relative_path="models/Stable-diffusion"), + ModelRef(arg_name="--vae-dir", relative_path="models/VAE"), + ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"), + ModelRef(arg_name="--embeddings-dir", relative_path="models/embeddings"), + ModelRef(arg_name="--lora-dir", relative_path="models/lora"), + # Ref A1111 need to have sd-webui-controlnet installed. + ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"), + ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"), + ] + + for ref in refs: + target_path = a1111_home / ref.relative_path + if not target_path.exists(): + print(f"Path {target_path} does not exist. Skip setting {ref.arg_name}") + continue + + if ref.arg_name in sys.argv: + # Do not override existing dir setting. + continue + + sys.argv.append(ref.arg_name) + sys.argv.append(str(target_path)) + + def start(): print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui diff --git a/modules_forge/shared.py b/modules_forge/shared.py index 8b03e788..ec44c9a4 100644 --- a/modules_forge/shared.py +++ b/modules_forge/shared.py @@ -2,12 +2,18 @@ import os import ldm_patched.modules.utils from modules.paths_internal import models_path +from modules.shared import cmd_opts - -controlnet_dir = os.path.join(models_path, 'ControlNet') +if cmd_opts.controlnet_dir: + controlnet_dir = str(cmd_opts.controlnet_dir) +else: + controlnet_dir = os.path.join(models_path, 'ControlNet') os.makedirs(controlnet_dir, exist_ok=True) -preprocessor_dir = os.path.join(models_path, 'ControlNetPreprocessor') +if cmd_opts.controlnet_preprocessor_models_dir: + preprocessor_dir = str(cmd_opts.controlnet_preprocessor_models_dir) +else: + preprocessor_dir = os.path.join(models_path, 'ControlNetPreprocessor') os.makedirs(preprocessor_dir, exist_ok=True) diffusers_dir = os.path.join(models_path, 'diffusers')