mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-26 17:39:02 +00:00
First pass at Apple Silicon MPS back end
This commit is contained in:
16
requirements-silicon.txt
Normal file
16
requirements-silicon.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
flask
|
||||
flask-cloudflared
|
||||
flask-cors
|
||||
flask-compress
|
||||
markdown
|
||||
Pillow
|
||||
colorama
|
||||
torch
|
||||
transformers==4.28.1
|
||||
webuiapi
|
||||
edge-tts
|
||||
silero-api-server
|
||||
torchvision
|
||||
torchaudio
|
||||
diffusers
|
||||
accelerate
|
||||
20
server.py
20
server.py
@@ -57,6 +57,7 @@ parser.add_argument(
|
||||
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
|
||||
parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
|
||||
parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
|
||||
parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
|
||||
parser.set_defaults(cpu=True)
|
||||
parser.add_argument("--summarization-model", help="Load a custom summarization model")
|
||||
parser.add_argument(
|
||||
@@ -67,7 +68,7 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode
|
||||
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
||||
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
||||
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
|
||||
parser.add_argument('--chroma-persist', help="Chromadb persistence", default=True, action=argparse.BooleanOptionalAction)
|
||||
parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
|
||||
parser.add_argument(
|
||||
"--secure", action="store_true", help="Enforces the use of an API key"
|
||||
)
|
||||
@@ -144,12 +145,15 @@ if len(modules) == 0:
|
||||
|
||||
# Models init
|
||||
cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
|
||||
device_string = cuda_device if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
||||
device = torch.device(device_string)
|
||||
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
||||
torch_dtype = torch.float32 if device_string != "cuda:0" else torch.float16
|
||||
|
||||
if not torch.cuda.is_available() and not args.cpu:
|
||||
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device. Defaulting to CPU mode.{Style.RESET_ALL}")
|
||||
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
|
||||
if not torch.backends.mps.is_available() and not args.cpu:
|
||||
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")
|
||||
|
||||
|
||||
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
||||
|
||||
@@ -186,12 +190,10 @@ if "sd" in modules and not sd_use_remote:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
print("Initializing Stable Diffusion pipeline")
|
||||
sd_device_string = (
|
||||
"cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
||||
)
|
||||
print("Initializing Stable Diffusion pipeline...")
|
||||
sd_device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
||||
sd_device = torch.device(sd_device_string)
|
||||
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
|
||||
sd_torch_dtype = torch.float32 if sd_device_string != "cpu" else torch.float16
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
|
||||
).to(sd_device)
|
||||
|
||||
Reference in New Issue
Block a user