diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 27f5fd6f..98bd6d2e 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -39,6 +39,10 @@ def initialize_forge(): args_parser.args, _ = args_parser.parser.parse_known_args() + if args_parser.args.gpu_device_id is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args_parser.args.gpu_device_id) + print("Set device to:", args_parser.args.gpu_device_id) + import ldm_patched.modules.model_management as model_management import torch