# webui-fooocus-prompt-expansion/scripts/expansion.py
# Fooocus GPT2 Expansion
# Algorithm created by Lvmin Zhang at 2023, Stanford
# Modified by power88 and GPT-4o for stable-diffusion-webui
# If used inside Fooocus, any use is permitted.
# If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
# This applies to the word list, vocab, model, and algorithm.
import os
import torch
import math
import shutil
import gradio as gr
import psutil
from pathlib import Path
from modules.scripts import basedir
from huggingface_hub import hf_hub_download
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules import scripts, paths_internal, errors, shared, script_callbacks
from modules.ui_components import InputAccordion


def text_encoder_device():
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    else:
        return torch.device("cpu")


def text_encoder_offload_device():
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    else:
        return torch.device("cpu")


# DirectML support is not wired up in this extension; the flag stays False so the
# checks below do not raise a NameError.
directml_enabled = False


def get_free_memory(dev=None, torch_free_too=False):
    global directml_enabled
    if dev is None:
        dev = text_encoder_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    else:
        if directml_enabled:
            mem_free_total = 1024 * 1024 * 1024  # TODO
            mem_free_torch = mem_free_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_cuda + mem_free_torch

    # Return the totals so callers such as should_use_fp16() can budget model memory.
    if torch_free_too:
        return mem_free_total, mem_free_torch
    return mem_free_total


# limitation of np.random.seed(), called from transformers.set_seed()
SEED_LIMIT_NUMPY = 2 ** 32
neg_inf = - 8192.0
ext_dir = Path(basedir())
fooocus_expansion_model_dir = Path(paths_internal.models_path) / "prompt_expansion"


def download_model():
    fooocus_expansion_model = fooocus_expansion_model_dir / "pytorch_model.bin"
    if not fooocus_expansion_model.exists():
        try:
            print('### webui-fooocus-prompt-expansion: Downloading model...')
            shutil.copytree(ext_dir / "models", fooocus_expansion_model_dir)
            hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin', local_dir=fooocus_expansion_model_dir)
            os.rename(fooocus_expansion_model_dir / 'fooocus_expansion.bin', fooocus_expansion_model)
        except Exception:
            errors.report('### webui-fooocus-prompt-expansion: Failed to download model', exc_info=True)
            print(f'Download the model manually from "https://huggingface.co/lllyasviel/misc/tree/main/fooocus_expansion.bin" and place it in {fooocus_expansion_model_dir}.')


def safe_str(x):
    x = str(x)
    for _ in range(16):
        # Collapse runs of repeated spaces into single spaces.
        x = x.replace('  ', ' ')
    return x.strip(",. \r\n")


def remove_pattern(x, pattern):
    for p in pattern:
        x = x.replace(p, '')
    return x


def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
    if device is not None and hasattr(device, 'type'):
        if device.type == 'cpu':
            return False

    if torch.cuda.is_bf16_supported():
        return True

    props = torch.cuda.get_device_properties("cuda")
    if props.major < 6:
        return False

    fp16_works = False
    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32, so it should only be enabled
    # when the model doesn't actually fit on the card.
    # TODO: actually test if GP106 and others have the same type of behavior
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
    for x in nvidia_10_series:
        if x in props.name.lower():
            fp16_works = True

    if fp16_works:
        free_model_memory = (get_free_memory() * 0.9 - (1024 * 1024 * 1024))
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    # FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
    for x in nvidia_16_series:
        if x in props.name:
            return False

    return True


def is_device_mps(device):
    if hasattr(device, 'type'):
        if device.type == 'mps':
            return True
    return False


class FooocusExpansion:
    def __init__(self):
        global load_model_device
        download_model()
        print(f'Loading models from {fooocus_expansion_model_dir}')

        self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)

        positive_words = open(os.path.join(fooocus_expansion_model_dir, 'positive.txt'),
                              encoding='utf-8').read().splitlines()
        positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']

        self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf

        debug_list = []
        for k, v in self.tokenizer.vocab.items():
            if k in positive_words:
                self.logits_bias[0, v] = 0
                debug_list.append(k[1:])

        print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')

        self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
        self.model.eval()

        load_model_device = text_encoder_device()
        offload_device = text_encoder_offload_device()

        # MPS hack
        if is_device_mps(load_model_device):
            load_model_device = torch.device('cpu')
            offload_device = torch.device('cpu')

        use_fp16 = should_use_fp16(device=load_model_device)

        if use_fp16:
            self.model.half()

        self.model.to(load_model_device)  # Ensure model is on the correct device
        print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')

    def unload_model(self):
        """Unload the model to free up memory."""
        del self.model
        torch.cuda.empty_cache()
        print('Model unloaded and memory cleared.')

    @torch.no_grad()
    @torch.inference_mode()
    def logits_processor(self, input_ids, scores):
        assert scores.ndim == 2 and scores.shape[0] == 1
        self.logits_bias = self.logits_bias.to(load_model_device)

        bias = self.logits_bias.clone().to(load_model_device)  # Ensure bias is on the correct device
        bias[0, input_ids[0].to(load_model_device).long()] = neg_inf  # Penalize tokens already in the prompt
        bias[0, 11] = 0  # 11 is the token id of ',' in the GPT-2 vocab, so commas stay available

        return scores + bias.to(scores.device)  # Ensure bias is on the same device as scores

    @torch.no_grad()
    @torch.inference_mode()
    def __call__(self, prompt, seed):
        if prompt == '':
            return ''

        seed = int(seed) % SEED_LIMIT_NUMPY
        set_seed(seed)

        prompt = safe_str(prompt) + ','

        tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(load_model_device)
        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(load_model_device)

        # Pad the expansion out to the next multiple of 75 tokens (the webui's prompt chunk size).
        current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
        max_new_tokens = max_token_length - current_token_length

        if max_new_tokens == 0:
            # Prompt already fills a chunk exactly; return it without the trailing comma.
            return prompt[:-1]

        features = self.model.generate(**tokenized_kwargs,
                                       top_k=100,
                                       max_new_tokens=max_new_tokens,
                                       do_sample=True,
                                       logits_processor=LogitsProcessorList([self.logits_processor]))

        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
        result = safe_str(response[0])

        return result


def createPositive(positive, seed):
    try:
        expansion = FooocusExpansion()
        positive = expansion(positive, seed=seed)
        expansion.unload_model()  # Unload the model after use
        return positive
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return positive  # Fall back to the unexpanded prompt on failure


class FooocusPromptExpansion(scripts.Script):
    infotext_fields = []

    def title(self):
        return 'Fooocus Prompt Expansion'

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label="Fooocus Expansion") as is_enabled:
            seed = gr.Number(value=0, maximum=63, label="Seed", info="Seed for the random number generator")
        self.infotext_fields.append((is_enabled, lambda d: False))
        return [is_enabled, seed]

    def process(self, p, is_enabled, seed):
        if not is_enabled:
            return

        for i, prompt in enumerate(p.all_prompts):
            positivePrompt = createPositive(prompt, seed)
            p.all_prompts[i] = positivePrompt

    def after_component(self, component, **kwargs):
        # helpful link: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/7456#issuecomment-1414465888
        # Find the txt2img positive prompt textbox component
        if kwargs.get("elem_id") == "txt2img_prompt":
            self.boxx = component
        # Find the img2img positive prompt textbox component
        if kwargs.get("elem_id") == "img2img_prompt":
            self.boxxIMG = component