download the model on first use

remove install.py
This commit is contained in:
w-e-w
2024-06-09 17:48:22 +09:00
parent 979e5e69da
commit 483fdc4330
2 changed files with 24 additions and 30 deletions

View File

@@ -1,23 +0,0 @@
import os
import pathlib
import shutil
from huggingface_hub import hf_hub_download
from modules.scripts import basedir
ext_dir = basedir()
fooocus_expansion_path = pathlib.Path(ext_dir) / "models" / "prompt_expansion"
base_model_path = pathlib.Path(ext_dir) / "extensions" / "webui-fooocus-prompt-expansion" / "models"
if not os.path.exists(os.path.join(fooocus_expansion_path, 'pytorch_model.bin')):
try:
print(f'### webui-fooocus-prompt-expansion: Downloading model...')
shutil.copytree(os.path.join(base_model_path), fooocus_expansion_path)
hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin', local_dir=os.path.join(fooocus_expansion_path), resume_download=True, local_dir_use_symlinks=False)
os.rename(os.path.join(fooocus_expansion_path, 'fooocus_expansion.bin'), os.path.join(fooocus_expansion_path, 'pytorch_model.bin'))
except Exception as e:
print(f'### webui-fooocus-prompt-expansion: Failed to download model...')
print(e)
print(f'### webui-fooocus-prompt-expansion: To enable this custom node, please download the model manually from "https://huggingface.co/lllyasviel/misc/tree/main/fooocus_expansion.bin" and place it in {fooocus_expansion_path}.')
else:
pass

View File

@@ -9,13 +9,16 @@
import os
import torch
import math
import shutil
import gradio as gr
import psutil
from pathlib import Path
from modules.scripts import basedir
from huggingface_hub import hf_hub_download
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules import scripts, paths_internal, shared, script_callbacks
from modules import scripts, paths_internal, errors, shared, script_callbacks
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
@@ -57,8 +60,21 @@ def get_free_memory(dev=None, torch_free_too=False):
# limitation of np.random.seed(), called from transformers.set_seed()
SEED_LIMIT_NUMPY = 2 ** 32
neg_inf = - 8192.0
ext_dir = basedir()
path_fooocus_expansion = os.path.join(paths_internal.models_path, "prompt_expansion")
ext_dir = Path(basedir())
fooocus_expansion_model_dir = Path(paths_internal.models_path) / "prompt_expansion"
def download_model():
fooocus_expansion_model = fooocus_expansion_model_dir / "pytorch_model.bin"
if not fooocus_expansion_model.exists():
try:
print(f'### webui-fooocus-prompt-expansion: Downloading model...')
shutil.copytree(ext_dir / "models", fooocus_expansion_model_dir)
hf_hub_download(repo_id='lllyasviel/misc', filename='fooocus_expansion.bin', local_dir=fooocus_expansion_model_dir)
os.rename(fooocus_expansion_model_dir / 'fooocus_expansion.bin', fooocus_expansion_model)
except Exception:
errors.report('### webui-fooocus-prompt-expansion: Failed to download model', exc_info=True)
print(f'Download the model manually from "https://huggingface.co/lllyasviel/misc/tree/main/fooocus_expansion.bin" and place it in {fooocus_expansion_model_dir}.')
def safe_str(x):
@@ -122,10 +138,11 @@ def is_device_mps(device):
class FooocusExpansion:
def __init__(self):
global load_model_device
print(f'Loading models from {path_fooocus_expansion}')
self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)
download_model()
print(f'Loading models from {fooocus_expansion_model_dir}')
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)
positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
positive_words = open(os.path.join(fooocus_expansion_model_dir, 'positive.txt'),
encoding='utf-8').read().splitlines()
positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']
@@ -139,7 +156,7 @@ class FooocusExpansion:
print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')
self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
self.model.eval()
load_model_device = text_encoder_device()