feature: implemented the ability to use a remote sd backend

This commit is contained in:
Babnam
2023-04-11 16:16:17 +02:00
parent 8c3f1f06f7
commit bda0aec777
2 changed files with 50 additions and 9 deletions

View File

@@ -7,3 +7,4 @@ colorama
--extra-index-url https://download.pytorch.org/whl/cu117
torch==2.0.0+cu117
git+https://github.com/huggingface/transformers
webuiapi

View File

@@ -13,6 +13,7 @@ from PIL import Image
import base64
from io import BytesIO
from random import randint
import webuiapi
from colorama import Fore, Style, init as colorama_init
colorama_init()
@@ -28,6 +29,8 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-large'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
#ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
@@ -63,10 +66,25 @@ parser.add_argument('--keyphrase-model',
help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
help="Load a custom prompt generation model")
parser.add_argument('--sd-model',
sd_group = parser.add_mutually_exclusive_group()
local_sd = sd_group.add_argument_group('sd-local')
local_sd.add_argument('--sd-model',
help="Load a custom SD image generation model")
parser.add_argument('--sd-cpu',
local_sd.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
remote_sd = sd_group.add_argument_group('sd-remote')
sd_group.add_argument('--sd-remote', action='store_true',
help="Use a remote SD API")
sd_group.add_argument('--sd-remote-host', type=str,
help="Specify the remote SD API host")
sd_group.add_argument('--sd-remote-port', type=int,
help="Specify the remote SD API port")
sd_group.add_argument('--sd-remote-username', type=str,
help="Specify the remote SD API username")
sd_group.add_argument('--sd-remote-password', type=str,
help="Specify the remote SD API password")
parser.add_argument('--enable-modules', action=SplitArgs, default=[],
help="Override a list of enabled modules")
@@ -79,7 +97,12 @@ classification_model = args.classification_model if args.classification_model el
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
sd_use_remote = True if args.sd_remote or args.sd_remote_host else False
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
modules = args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
if len(modules) == 0:
@@ -120,7 +143,7 @@ if 'prompt' in modules:
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
prompt_generator = pipeline('text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
if 'sd' in modules:
if 'sd' in modules and not sd_use_remote:
from diffusers import StableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler
print('Initializing Stable Diffusion pipeline')
@@ -132,6 +155,14 @@ if 'sd' in modules:
sd_pipe.enable_attention_slicing()
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
elif 'sd' in modules and sd_use_remote:
print('Initializing remote SD pipeline')
sd_remote = webuiapi.WebUIApi(host=sd_remote_host, port=sd_remote_port, use_https=False, sampler='Euler a', steps=30)
if args.sd_remote_username and args.sd_remote_password:
sd_remote.set_auth(args.sd_remote_username, args.sd_remote_password)
sd_remote.util_set_model('Anything')
sd_remote.util_wait_for_ready()
prompt_prefix = "best quality, absurdres, "
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
@@ -230,12 +261,21 @@ def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image:
prompt = normalize_string(f'{prompt_prefix}{input}')
print(prompt)
image = sd_pipe(
prompt=prompt,
negative_prompt=neg_prompt,
num_inference_steps=steps,
guidance_scale=scale,
).images[0]
if sd_use_remote:
image = sd_remote.txt2img(
prompt=prompt,
negative_prompt=neg_prompt,
sampler_index='Euler a',
steps=steps,
cfg_scale=scale,
).image
else:
image = sd_pipe(
prompt=prompt,
negative_prompt=neg_prompt,
num_inference_steps=steps,
guidance_scale=scale,
).images[0]
image.save("./debug.png")
return image