diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 02914592..7d475a26 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -7,9 +7,19 @@ def initialize_forge(): "Use this when you ara on MAC or have more than 20GB VRAM like RTX4096.") args_parser.args = args_parser.parser.parse_known_args()[0] - args_parser.args.always_offload_from_vram = not args_parser.args.disable_offload_from_vram import ldm_patched.modules.model_management as model_management + + if args_parser.args.disable_offload_from_vram: + print('User disabled VRAM offload.') + model_management.ALWAYS_VRAM_OFFLOAD = False + elif model_management.total_vram > 20 * 1024: + print('Automatically disable VRAM offload since user have more than 20GB VRAM.') + model_management.ALWAYS_VRAM_OFFLOAD = False + else: + print('Always offload models from VRAM.') + model_management.ALWAYS_VRAM_OFFLOAD = True + import torch device = model_management.get_torch_device()