mirror of
https://github.com/power88/webui-fooocus-prompt-expansion.git
synced 2026-04-24 00:08:55 +00:00
simplifies and remove unused code
This commit is contained in:
@@ -7,18 +7,18 @@
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import math
|
||||
import shutil
|
||||
import gradio as gr
|
||||
import psutil
|
||||
|
||||
from pathlib import Path
|
||||
from modules.scripts import basedir
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers.generation.logits_process import LogitsProcessorList
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
||||
from modules import scripts, paths_internal, errors, devices, shared, script_callbacks
|
||||
from modules import scripts, paths_internal, errors, devices
|
||||
from modules.ui_components import InputAccordion
|
||||
from functools import lru_cache
|
||||
|
||||
@@ -44,28 +44,12 @@ def download_model():
|
||||
|
||||
|
||||
def safe_str(x):
|
||||
x = str(x)
|
||||
for _ in range(16):
|
||||
x = x.replace(' ', ' ')
|
||||
return x.strip(",. \r\n")
|
||||
|
||||
|
||||
def remove_pattern(x, pattern):
|
||||
for p in pattern:
|
||||
x = x.replace(p, '')
|
||||
return x
|
||||
|
||||
|
||||
def is_device_mps(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'mps'):
|
||||
return True
|
||||
return False
|
||||
return re.sub(r' +', r' ', x).strip(",. \r\n")
|
||||
|
||||
|
||||
class FooocusExpansion:
|
||||
def __init__(self):
|
||||
global load_model_device
|
||||
|
||||
download_model()
|
||||
print(f'Loading models from {fooocus_expansion_model_dir}')
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)
|
||||
@@ -87,14 +71,14 @@ class FooocusExpansion:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
|
||||
self.model.eval()
|
||||
|
||||
load_model_device = devices.get_optimal_device_name()
|
||||
self.load_model_device = devices.get_optimal_device_name()
|
||||
use_fp16 = devices.dtype == torch.float16
|
||||
if use_fp16:
|
||||
self.model.half()
|
||||
|
||||
self.model.to(load_model_device) # Ensure model is on the correct device
|
||||
self.model.to(self.load_model_device) # Ensure the model is on the correct device
|
||||
|
||||
print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')
|
||||
print(f'Fooocus Expansion engine loaded for {self.load_model_device}, use_fp16 = {use_fp16}.')
|
||||
|
||||
def unload_model(self):
|
||||
"""Unload the model to free up memory."""
|
||||
@@ -106,10 +90,10 @@ class FooocusExpansion:
|
||||
@torch.inference_mode()
|
||||
def logits_processor(self, input_ids, scores):
|
||||
assert scores.ndim == 2 and scores.shape[0] == 1
|
||||
self.logits_bias = self.logits_bias.to(load_model_device)
|
||||
self.logits_bias = self.logits_bias.to(self.load_model_device)
|
||||
|
||||
bias = self.logits_bias.clone().to(load_model_device) # Ensure bias is on the correct device
|
||||
bias[0, input_ids[0].to(load_model_device).long()] = neg_inf # Ensure input_ids are on the correct device
|
||||
bias = self.logits_bias.clone().to(self.load_model_device) # Ensure bias is on the correct device
|
||||
bias[0, input_ids[0].to(self.load_model_device).long()] = neg_inf # Ensure input_ids are on the correct device
|
||||
bias[0, 11] = 0
|
||||
|
||||
return scores + bias.to(scores.device) # Ensure bias is on the same device as scores
|
||||
@@ -124,8 +108,8 @@ class FooocusExpansion:
|
||||
set_seed(seed)
|
||||
prompt = safe_str(prompt) + ','
|
||||
tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
|
||||
tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(load_model_device)
|
||||
tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(load_model_device)
|
||||
tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.load_model_device)
|
||||
tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.load_model_device)
|
||||
|
||||
current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
|
||||
max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
|
||||
@@ -199,8 +183,8 @@ class FooocusPromptExpansion(scripts.Script):
|
||||
return
|
||||
|
||||
for i, prompt in enumerate(p.all_prompts):
|
||||
positivePrompt = create_positive(prompt, seed)
|
||||
p.all_prompts[i] = positivePrompt
|
||||
positive_prompt = create_positive(prompt, seed)
|
||||
p.all_prompts[i] = positive_prompt
|
||||
|
||||
def save_prompt_box(self, on_component):
|
||||
self.prompt_elm = on_component.component
|
||||
|
||||
Reference in New Issue
Block a user