From 8a383844bf9fede11f7197b357b7b98852ac6f9b Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:27:53 +0900
Subject: [PATCH 01/12] .gitignore __pycache__/

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c18dd8d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__/

From 7486e593a5749c7f3a639c58d5a9ae62b36ba6ee Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:28:50 +0900
Subject: [PATCH 02/12] reformat

---
 scripts/expansion.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 9109b8d..05ded02 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -25,12 +25,14 @@ def text_encoder_device():
     else:
         return torch.device("cpu")

+
 def text_encoder_offload_device():
     if torch.cuda.is_available():
         return torch.device(torch.cuda.current_device())
     else:
         return torch.device("cpu")

+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_enabled
     if dev is None:
@@ -41,7 +43,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_torch = mem_free_total
     else:
         if directml_enabled:
-            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_total = 1024 * 1024 * 1024  # TODO
             mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
@@ -52,9 +54,8 @@ def get_free_memory(dev=None, torch_free_too=False):

     mem_free_total = mem_free_cuda + mem_free_torch

-
 # limitation of np.random.seed(), called from transformers.set_seed()
-SEED_LIMIT_NUMPY = 2**32
+SEED_LIMIT_NUMPY = 2 ** 32
 neg_inf = - 8192.0
 ext_dir = basedir()
 path_fooocus_expansion = os.path.join('.', "models", "prompt_expansion")
@@ -72,6 +73,7 @@ def remove_pattern(x, pattern):
         x = x.replace(p, '')
     return x

+
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if device is not None:
         if hasattr(device, 'type'):
@@ -85,9 +87,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         return False

     fp16_works = False
-    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    #when the model doesn't actually fit on the card
-    #TODO: actually test if GP106 and others have the same type of behavior
+    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
+    # when the model doesn't actually fit on the card
+    # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
     for x in nvidia_10_series:
         if x in props.name.lower():
@@ -101,7 +103,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if props.major < 7:
         return False

-    #FP16 is just broken on these cards
+    # FP16 is just broken on these cards
     nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
     for x in nvidia_16_series:
         if x in props.name:
@@ -109,11 +111,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):

     return True

+
 def is_device_mps(device):
     if hasattr(device, 'type'):
         if (device.type == 'mps'):
             return True
-    return False
+    return False
+

 class FooocusExpansion:
     def __init__(self):
@@ -201,6 +205,7 @@ class FooocusExpansion:
         return result

+
 def createPositive(positive, seed):
     try:
         expansion = FooocusExpansion()
@@ -210,16 +215,17 @@ def createPositive(positive, seed):
     except Exception as e:
         print(f"An error occurred: {str(e)}")

+
 class FooocusPromptExpansion(scripts.Script):
     def __init__(self) -> None:
         super().__init__()
-    
+
     def title(self):
         return 'Fooocus Prompt Expansion'
-    
+
     def show(self, is_img2img):
         return scripts.AlwaysVisible
-    
+
     def ui(self, is_img2img):
         with gr.Group():
             with gr.Accordion("Fooocus Expansion", open=True):
@@ -228,7 +234,7 @@ class FooocusPromptExpansion(scripts.Script):
                 seed = gr.Number(
                     value=0, maximum=63, label="Seed", info="Seed for random number generator")
         return [is_enabled, seed]
-    
+
     def process(self, p, is_enabled, seed):
         if not is_enabled:
             return
@@ -237,8 +243,6 @@ class FooocusPromptExpansion(scripts.Script):
             positivePrompt = createPositive(prompt, seed)
             p.all_prompts[i] = positivePrompt

-
-
     def after_component(self, component, **kwargs):
         # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/7456#issuecomment-1414465888 helpfull link
         # Find the text2img textbox component

From 979e5e69daf23f08691a165fc84b04ce548c5363 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:31:04 +0900
Subject: [PATCH 03/12] fix models_path

---
 scripts/expansion.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 05ded02..6416357 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -15,7 +15,7 @@ import psutil
 from modules.scripts import basedir
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-from modules import scripts, shared, script_callbacks
+from modules import scripts, paths_internal, shared, script_callbacks
 from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton


@@ -58,7 +58,7 @@ def get_free_memory(dev=None, torch_free_too=False):
 SEED_LIMIT_NUMPY = 2 ** 32
 neg_inf = - 8192.0
 ext_dir = basedir()
-path_fooocus_expansion = os.path.join('.', "models", "prompt_expansion")
+path_fooocus_expansion = os.path.join(paths_internal.models_path, "prompt_expansion")


 def safe_str(x):

From 483fdc433055b4b30c6b5a90d9224e93ebab7e20 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:48:22 +0900
Subject: [PATCH 04/12] download the model on first use, remove install.py

---
 install.py           | 23 -----------------------
 scripts/expansion.py | 31 ++++++++++++++++++++++++-------
 2 files changed, 24 insertions(+), 30 deletions(-)
 delete mode 100644 install.py

diff --git a/install.py b/install.py
deleted file mode 100644
index 8edf1f5..0000000
--- a/install.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-import pathlib
-import shutil
-from huggingface_hub import hf_hub_download
-from modules.scripts import basedir
-
-ext_dir = basedir()
-fooocus_expansion_path = pathlib.Path(ext_dir) / "models" / "prompt_expansion"
-base_model_path = pathlib.Path(ext_dir) / "extensions" / "webui-fooocus-prompt-expansion" / "models"
-
-
-if not os.path.exists(os.path.join(fooocus_expansion_path, 'pytorch_model.bin')):
-    try:
-        print(f'### webui-fooocus-prompt-expansion: Downloading model...')
-        shutil.copytree(os.path.join(base_model_path), fooocus_expansion_path)
-        hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin', local_dir=os.path.join(fooocus_expansion_path), resume_download=True, local_dir_use_symlinks=False)
-        os.rename(os.path.join(fooocus_expansion_path, 'fooocus_expansion.bin'), os.path.join(fooocus_expansion_path, 'pytorch_model.bin'))
-    except Exception as e:
-        print(f'### webui-fooocus-prompt-expansion: Failed to download model...')
-        print(e)
-        print(f'### webui-fooocus-prompt-expansion: To enable this custom node, please download the model manually from "https://huggingface.co/lllyasviel/misc/tree/main/fooocus_expansion.bin" and place it in {fooocus_expansion_path}.')
-else:
-    pass

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 6416357..ef2a096 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -9,13 +9,16 @@ import os

 import torch
 import math
+import shutil
 import gradio as gr
 import psutil

+from pathlib import Path
 from modules.scripts import basedir
+from huggingface_hub import hf_hub_download
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-from modules import scripts, paths_internal, shared, script_callbacks
+from modules import scripts, paths_internal, errors, shared, script_callbacks
 from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
@@ -57,8 +60,21 @@ def get_free_memory(dev=None, torch_free_too=False):
 # limitation of np.random.seed(), called from transformers.set_seed()
 SEED_LIMIT_NUMPY = 2 ** 32
 neg_inf = - 8192.0
-ext_dir = basedir()
-path_fooocus_expansion = os.path.join(paths_internal.models_path, "prompt_expansion")
+ext_dir = Path(basedir())
+fooocus_expansion_model_dir = Path(paths_internal.models_path) / "prompt_expansion"
+
+
+def download_model():
+    fooocus_expansion_model = fooocus_expansion_model_dir / "pytorch_model.bin"
+    if not fooocus_expansion_model.exists():
+        try:
+            print(f'### webui-fooocus-prompt-expansion: Downloading model...')
+            shutil.copytree(ext_dir / "models", fooocus_expansion_model_dir)
+            hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin', local_dir=fooocus_expansion_model_dir)
+            os.rename(fooocus_expansion_model_dir / 'fooocus_expansion.bin', fooocus_expansion_model)
+        except Exception:
+            errors.report('### webui-fooocus-prompt-expansion: Failed to download model', exc_info=True)
+            print(f'Download the model manually from "https://huggingface.co/lllyasviel/misc/tree/main/fooocus_expansion.bin" and place it in {fooocus_expansion_model_dir}.')


 def safe_str(x):
@@ -122,10 +138,11 @@ def is_device_mps(device):
 class FooocusExpansion:
     def __init__(self):
         global load_model_device
-        print(f'Loading models from {path_fooocus_expansion}')
-        self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)
+        download_model()
+        print(f'Loading models from {fooocus_expansion_model_dir}')
+        self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)

-        positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
+        positive_words = open(os.path.join(fooocus_expansion_model_dir, 'positive.txt'),
                               encoding='utf-8').read().splitlines()
         positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']
@@ -139,7 +156,7 @@ class FooocusExpansion:

         print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')

-        self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
+        self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
         self.model.eval()

         load_model_device = text_encoder_device()
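[Editor's note, not part of the patch series] PATCH 04 moves the model fetch from install time to first use: download_model() copies the tokenizer and config files bundled with the extension into models/prompt_expansion, pulls fooocus_expansion.bin from the lllyasviel/misc repository with hf_hub_download(), and renames it to pytorch_model.bin so transformers can load it. A minimal standalone sketch of that flow, assuming only that huggingface_hub is installed; model_dir below is a stand-in for the webui's paths_internal.models_path:

    from pathlib import Path
    from huggingface_hub import hf_hub_download

    model_dir = Path("models") / "prompt_expansion"  # stand-in for paths_internal.models_path
    weights = model_dir / "pytorch_model.bin"
    if not weights.exists():
        model_dir.mkdir(parents=True, exist_ok=True)
        # Fetch the checkpoint into model_dir, then rename it to the
        # filename transformers expects.
        hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin',
                        local_dir=model_dir)
        (model_dir / 'fooocus_expansion.bin').rename(weights)

Since the existence check runs on every FooocusExpansion() construction, the download cost is paid at most once; later constructions only stat the file.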
From c59d9d103292b1d91cf1ed9abdc093978fa8e473 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:51:05 +0900
Subject: [PATCH 05/12] use InputAccordion

---
 scripts/expansion.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index ef2a096..efc1ffe 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -19,7 +19,7 @@ from huggingface_hub import hf_hub_download
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules import scripts, paths_internal, errors, shared, script_callbacks
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
+from modules.ui_components import InputAccordion


 def text_encoder_device():
@@ -244,12 +244,8 @@ class FooocusPromptExpansion(scripts.Script):
         return scripts.AlwaysVisible

     def ui(self, is_img2img):
-        with gr.Group():
-            with gr.Accordion("Fooocus Expansion", open=True):
-                is_enabled = gr.Checkbox(
-                    value=True, label="Enable Expansion", info="Enable Or Disable Expansion ")
-                seed = gr.Number(
-                    value=0, maximum=63, label="Seed", info="Seed for random number generator")
+        with InputAccordion(False, label="Fooocus Expansion") as is_enabled:
+            seed = gr.Number(value=0, maximum=63, label="Seed", info="Seed for random number generator")
         return [is_enabled, seed]

     def process(self, p, is_enabled, seed):

From 90ace756632af8e423dbb3e69bb21a05a2a7043d Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 17:59:41 +0900
Subject: [PATCH 06/12] disable ext on apply infotext

---
 scripts/expansion.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index efc1ffe..5be3515 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -234,8 +234,7 @@


 class FooocusPromptExpansion(scripts.Script):
-    def __init__(self) -> None:
-        super().__init__()
+    infotext_fields = []

     def title(self):
         return 'Fooocus Prompt Expansion'
@@ -246,6 +245,7 @@ class FooocusPromptExpansion(scripts.Script):
     def ui(self, is_img2img):
         with InputAccordion(False, label="Fooocus Expansion") as is_enabled:
             seed = gr.Number(value=0, maximum=63, label="Seed", info="Seed for random number generator")
+        self.infotext_fields.append((is_enabled, lambda d: False))
         return [is_enabled, seed]

     def process(self, p, is_enabled, seed):

From 3b7cb8d8b30c22d7ae3938f270f5a97e1642b365 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 18:12:36 +0900
Subject: [PATCH 07/12] lru_cache create_positive

---
 scripts/expansion.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 5be3515..52918d9 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -20,6 +20,7 @@ from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules import scripts, paths_internal, errors, shared, script_callbacks
 from modules.ui_components import InputAccordion
+from functools import lru_cache


 def text_encoder_device():
@@ -223,14 +224,12 @@ class FooocusExpansion:

         return result

-def createPositive(positive, seed):
-    try:
-        expansion = FooocusExpansion()
-        positive = expansion(positive, seed=seed)
-        expansion.unload_model()  # Unload the model after use
-        return positive
-    except Exception as e:
-        print(f"An error occurred: {str(e)}")
+@lru_cache(maxsize=1024)
+def create_positive(positive, seed):
+    expansion = FooocusExpansion()
+    positive = expansion(positive, seed=seed)
+    expansion.unload_model()  # Unload the model after use
+    return positive


 class FooocusPromptExpansion(scripts.Script):
@@ -253,7 +252,7 @@ class FooocusPromptExpansion(scripts.Script):
             return

         for i, prompt in enumerate(p.all_prompts):
-            positivePrompt = createPositive(prompt, seed)
+            positivePrompt = create_positive(prompt, seed)
             p.all_prompts[i] = positivePrompt

     def after_component(self, component, **kwargs):

From 211e563e7160d679f98351a173edf86feaebb14b Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 18:36:51 +0900
Subject: [PATCH 08/12] simplify device

---
 scripts/expansion.py | 87 ++------------------------------------------
 1 file changed, 3 insertions(+), 84 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 52918d9..a632cf1 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -18,46 +18,11 @@ from modules.scripts import basedir
 from huggingface_hub import hf_hub_download
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-from modules import scripts, paths_internal, errors, shared, script_callbacks
+from modules import scripts, paths_internal, errors, devices, shared, script_callbacks
 from modules.ui_components import InputAccordion
 from functools import lru_cache


-def text_encoder_device():
-    if torch.cuda.is_available():
-        return torch.device(torch.cuda.current_device())
-    else:
-        return torch.device("cpu")
-
-
-def text_encoder_offload_device():
-    if torch.cuda.is_available():
-        return torch.device(torch.cuda.current_device())
-    else:
-        return torch.device("cpu")
-
-
-def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
-    if dev is None:
-        dev = text_encoder_device()
-
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_free_total = psutil.virtual_memory().available
-        mem_free_torch = mem_free_total
-    else:
-        if directml_enabled:
-            mem_free_total = 1024 * 1024 * 1024  # TODO
-            mem_free_torch = mem_free_total
-        else:
-            stats = torch.cuda.memory_stats(dev)
-            mem_active = stats['active_bytes.all.current']
-            mem_reserved = stats['reserved_bytes.all.current']
-            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-            mem_free_torch = mem_reserved - mem_active
-            mem_free_total = mem_free_cuda + mem_free_torch
-
-
 # limitation of np.random.seed(), called from transformers.set_seed()
 SEED_LIMIT_NUMPY = 2 ** 32
 neg_inf = - 8192.0
@@ -91,44 +56,6 @@ def remove_pattern(x, pattern):
         x = x.replace(p, '')
     return x


-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
-    if device is not None:
-        if hasattr(device, 'type'):
-            if device.type == 'cpu':
-                return False
-        return False
-
-    if torch.cuda.is_bf16_supported():
-        return True
-
-    props = torch.cuda.get_device_properties("cuda")
-    if props.major < 6:
-        return False
-
-    fp16_works = False
-    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    # when the model doesn't actually fit on the card
-    # TODO: actually test if GP106 and others have the same type of behavior
-    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
-    for x in nvidia_10_series:
-        if x in props.name.lower():
-            fp16_works = True
-
-    if fp16_works:
-        free_model_memory = (get_free_memory() * 0.9 - (1024 * 1024 * 1024))
-        if (not prioritize_performance) or model_params * 4 > free_model_memory:
-            return True
-
-    if props.major < 7:
-        return False
-
-    # FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
-    for x in nvidia_16_series:
-        if x in props.name:
-            return False
-
-    return True
-
-
 def is_device_mps(device):
     if hasattr(device, 'type'):
         if (device.type == 'mps'):
@@ -160,16 +87,8 @@ class FooocusExpansion:
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
         self.model.eval()

-        load_model_device = text_encoder_device()
-        offload_device = text_encoder_offload_device()
-
-        # MPS hack
-        if is_device_mps(load_model_device):
-            load_model_device = torch.device('cpu')
-            offload_device = torch.device('cpu')
-
-        use_fp16 = should_use_fp16(device=load_model_device)
-
+        load_model_device = devices.get_optimal_device_name()
+        use_fp16 = devices.dtype == torch.float16
         if use_fp16:
             self.model.half()

From 68b1975a5bfd4929cd1f9e8cbd1b73280efff391 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 19:07:32 +0900
Subject: [PATCH 09/12] Implement preview UI

---
 scripts/expansion.py | 40 +++++++++++++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 9 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index a632cf1..05ac16e 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -153,6 +153,14 @@ def create_positive(positive, seed):

 class FooocusPromptExpansion(scripts.Script):
     infotext_fields = []
+    prompt_elm = None
+
+    def __init__(self):
+        super().__init__()
+        self.on_after_component_elem_id = [
+            ('txt2img_prompt', self.save_prompt_box),
+            ('img2img_prompt', self.save_prompt_box),
+        ]

     def title(self):
         return 'Fooocus Prompt Expansion'
@@ -162,8 +170,28 @@ class FooocusPromptExpansion(scripts.Script):
         return scripts.AlwaysVisible

     def ui(self, is_img2img):
         with InputAccordion(False, label="Fooocus Expansion") as is_enabled:
-            seed = gr.Number(value=0, maximum=63, label="Seed", info="Seed for random number generator")
+            seed = gr.Number(value=0, label="Seed", info="Seed for random number generator")
+            if self.prompt_elm is not None:
+                with gr.Row():
+                    generate = gr.Button('Generate expansion prompts')
+                    apply = gr.Button('Apply expansion to prompts')
+                preview = gr.Textbox('', label="Expansion preview", interactive=False)
+
+                for x in [preview, generate, apply]:
+                    x.save_to_config = False
+
+                generate.click(
+                    fn=create_positive,
+                    inputs=[self.prompt_elm, seed],
+                    outputs=[preview],
+                )
+                apply.click(
+                    fn=lambda *args: (False, create_positive(*args)),
+                    inputs=[self.prompt_elm, seed],
+                    outputs=[is_enabled, self.prompt_elm],
+                )
         self.infotext_fields.append((is_enabled, lambda d: False))
+
         return [is_enabled, seed]

     def process(self, p, is_enabled, seed):
@@ -174,11 +202,5 @@ class FooocusPromptExpansion(scripts.Script):
             positivePrompt = create_positive(prompt, seed)
             p.all_prompts[i] = positivePrompt

-    def after_component(self, component, **kwargs):
-        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/7456#issuecomment-1414465888 helpfull link
-        # Find the text2img textbox component
-        if kwargs.get("elem_id") == "txt2img_prompt":  # postive prompt textbox
-            self.boxx = component
-        # Find the img2img textbox component
-        if kwargs.get("elem_id") == "img2img_prompt":  # postive prompt textbox
-            self.boxxIMG = component
+    def save_prompt_box(self, on_component):
+        self.prompt_elm = on_component.component
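[Editor's note, not part of the patch series] PATCH 09 replaces the old after_component() scan, which compared elem_id against every component the UI built, with the declarative on_after_component_elem_id list: one callback per prompt textbox, each handed an object whose .component attribute is the constructed Gradio component. Two wiring details are worth calling out. First, generate.click() writes the expansion of the live prompt into a read-only preview box without touching the prompt itself. Second, the Apply button's lambda returns the tuple (False, create_positive(*args)), which in one step unticks the InputAccordion (so process() will not expand again) and overwrites the prompt box with the expanded text. A compressed sketch of that return-tuple pattern in plain Gradio, with hypothetical component names:

    import gradio as gr

    def apply_expansion(prompt, seed):
        expanded = prompt + ", extra detail"  # stand-in for create_positive(prompt, seed)
        # First value unticks the enable checkbox, second replaces the prompt text.
        return False, expanded

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(value=0, label="Seed")
        enabled = gr.Checkbox(value=True, label="Enable Expansion")
        apply = gr.Button("Apply expansion to prompts")
        apply.click(fn=apply_expansion, inputs=[prompt, seed], outputs=[enabled, prompt])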
From 22e7b1ced3fceaf53c51be665bdf8e171f3cc9f0 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 19:21:15 +0900
Subject: [PATCH 10/12] simplify and remove unused code

---
 scripts/expansion.py | 44 ++++++++++++++------------------------------
 1 file changed, 14 insertions(+), 30 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 05ac16e..6d84aa5 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -7,18 +7,18 @@


 import os
+import re
 import torch
 import math
 import shutil
 import gradio as gr
-import psutil

 from pathlib import Path
 from modules.scripts import basedir
 from huggingface_hub import hf_hub_download
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-from modules import scripts, paths_internal, errors, devices, shared, script_callbacks
+from modules import scripts, paths_internal, errors, devices
 from modules.ui_components import InputAccordion
 from functools import lru_cache
@@ -44,28 +44,12 @@ def download_model():


 def safe_str(x):
-    x = str(x)
-    for _ in range(16):
-        x = x.replace('  ', ' ')
-    return x.strip(",. \r\n")
-
-
-def remove_pattern(x, pattern):
-    for p in pattern:
-        x = x.replace(p, '')
-    return x
-
-
-def is_device_mps(device):
-    if hasattr(device, 'type'):
-        if (device.type == 'mps'):
-            return True
-    return False
+    return re.sub(r' +', r' ', x).strip(",. \r\n")


 class FooocusExpansion:
     def __init__(self):
-        global load_model_device
+
         download_model()
         print(f'Loading models from {fooocus_expansion_model_dir}')
         self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)
@@ -87,14 +71,14 @@ class FooocusExpansion:
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
         self.model.eval()

-        load_model_device = devices.get_optimal_device_name()
+        self.load_model_device = devices.get_optimal_device_name()
         use_fp16 = devices.dtype == torch.float16
         if use_fp16:
             self.model.half()

-        self.model.to(load_model_device)  # Ensure model is on the correct device
+        self.model.to(self.load_model_device)  # Ensure the model is on the correct device

-        print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')
+        print(f'Fooocus Expansion engine loaded for {self.load_model_device}, use_fp16 = {use_fp16}.')

     def unload_model(self):
         """Unload the model to free up memory."""
@@ -106,10 +90,10 @@ class FooocusExpansion:
     @torch.inference_mode()
     def logits_processor(self, input_ids, scores):
         assert scores.ndim == 2 and scores.shape[0] == 1
-        self.logits_bias = self.logits_bias.to(load_model_device)
+        self.logits_bias = self.logits_bias.to(self.load_model_device)

-        bias = self.logits_bias.clone().to(load_model_device)  # Ensure bias is on the correct device
-        bias[0, input_ids[0].to(load_model_device).long()] = neg_inf  # Ensure input_ids are on the correct device
+        bias = self.logits_bias.clone().to(self.load_model_device)  # Ensure bias is on the correct device
+        bias[0, input_ids[0].to(self.load_model_device).long()] = neg_inf  # Ensure input_ids are on the correct device
         bias[0, 11] = 0

         return scores + bias.to(scores.device)  # Ensure bias is on the same device as scores
@@ -124,8 +108,8 @@ class FooocusExpansion:
         set_seed(seed)
         prompt = safe_str(prompt) + ','
         tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
-        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(load_model_device)
-        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(load_model_device)
+        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.load_model_device)
+        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.load_model_device)

         current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
         max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
@@ -199,8 +183,8 @@ class FooocusPromptExpansion(scripts.Script):
             return

         for i, prompt in enumerate(p.all_prompts):
-            positivePrompt = create_positive(prompt, seed)
-            p.all_prompts[i] = positivePrompt
+            positive_prompt = create_positive(prompt, seed)
+            p.all_prompts[i] = positive_prompt

     def save_prompt_box(self, on_component):
         self.prompt_elm = on_component.component

From 14dbf81da4b56638024a28d69b92e3ea1bdd5708 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 19:29:38 +0900
Subject: [PATCH 11/12] add empty prompt check in create_positive

---
 scripts/expansion.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 6d84aa5..8004a57 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -101,7 +101,7 @@ class FooocusExpansion:
     @torch.no_grad()
     @torch.inference_mode()
     def __call__(self, prompt, seed):
-        if prompt == '':
+        if not prompt:
             return ''

         seed = int(seed) % SEED_LIMIT_NUMPY
@@ -129,6 +129,8 @@ class FooocusExpansion:

 @lru_cache(maxsize=1024)
 def create_positive(positive, seed):
+    if not positive:
+        return ''
     expansion = FooocusExpansion()
     positive = expansion(positive, seed=seed)
     expansion.unload_model()  # Unload the model after use

From 4f2a4488a4813332c76c24d1a5dac475f23aa587 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 9 Jun 2024 19:43:23 +0900
Subject: [PATCH 12/12] Update expansion.py

---
 scripts/expansion.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/scripts/expansion.py b/scripts/expansion.py
index 8004a57..e8cfbb9 100644
--- a/scripts/expansion.py
+++ b/scripts/expansion.py
@@ -185,8 +185,7 @@ class FooocusPromptExpansion(scripts.Script):
             return

         for i, prompt in enumerate(p.all_prompts):
-            positive_prompt = create_positive(prompt, seed)
-            p.all_prompts[i] = positive_prompt
+            p.all_prompts[i] = create_positive(prompt, seed)

     def save_prompt_box(self, on_component):
         self.prompt_elm = on_component.component
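[Editor's note, not part of the patch series] After PATCH 12 the extension's programmatic surface is a single entry point: create_positive(prompt, seed), now memoized and empty-safe. A minimal sketch of the resulting behaviour; the import path assumes the extension's scripts/ directory is importable, which depends on how the webui loads extensions:

    from scripts.expansion import create_positive

    # First call: downloads the model if missing, loads it, expands the
    # prompt, then unloads the model again to free memory.
    expanded = create_positive("a photo of a cat", 0)

    # Same (prompt, seed) pair: answered from @lru_cache(maxsize=1024),
    # so the model is not reloaded for repeated prompts in a batch.
    assert create_positive("a photo of a cat", 0) == expanded

    # Empty prompts short-circuit to '' before any model work (PATCH 11).
    assert create_positive("", 0) == ''

One subtlety: gr.Number delivers the seed as a float, but lru_cache's default key (typed=False) treats equal values such as 0 and 0.0 as the same entry, so UI-driven and programmatic calls share the cache.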