Extra modules (except SD) now use CPU by default

This commit is contained in:
SillyLossy
2023-06-18 21:51:31 +03:00
parent 94ab92972c
commit db0b232903
2 changed files with 7 additions and 6 deletions

View File

@@ -54,7 +54,7 @@ parser.add_argument(
parser.add_argument(
"--share", action="store_true", help="Share the app on CloudFlare tunnel"
)
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU", default=True)
parser.add_argument("--summarization-model", help="Load a custom summarization model")
parser.add_argument(
"--classification-model", help="Load a custom text classification model"
@@ -73,7 +73,7 @@ sd_group = parser.add_mutually_exclusive_group()
local_sd = sd_group.add_argument_group("sd-local")
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU")
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
remote_sd = sd_group.add_argument_group("sd-remote")
remote_sd.add_argument(
@@ -144,6 +144,8 @@ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
if "caption" in modules:
print("Initializing an image captioning model...")
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
@@ -248,7 +250,7 @@ if "chromadb" in modules:
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
else:
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
print(f"ChromaDB is running in-memory without persistence.")
else:
chroma_port=(