simplifies and remove unused code

This commit is contained in:
w-e-w
2024-06-09 19:21:15 +09:00
parent 68b1975a5b
commit 22e7b1ced3

View File

@@ -7,18 +7,18 @@
import os
import re
import torch
import math
import shutil
import gradio as gr
import psutil
from pathlib import Path
from modules.scripts import basedir
from huggingface_hub import hf_hub_download
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules import scripts, paths_internal, errors, devices, shared, script_callbacks
from modules import scripts, paths_internal, errors, devices
from modules.ui_components import InputAccordion
from functools import lru_cache
@@ -44,28 +44,12 @@ def download_model():
def safe_str(x):
x = str(x)
for _ in range(16):
x = x.replace(' ', ' ')
return x.strip(",. \r\n")
def remove_pattern(x, pattern):
for p in pattern:
x = x.replace(p, '')
return x
def is_device_mps(device):
if hasattr(device, 'type'):
if (device.type == 'mps'):
return True
return False
return re.sub(r' +', r' ', x).strip(",. \r\n")
class FooocusExpansion:
def __init__(self):
global load_model_device
download_model()
print(f'Loading models from {fooocus_expansion_model_dir}')
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_model_dir)
@@ -87,14 +71,14 @@ class FooocusExpansion:
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_model_dir)
self.model.eval()
load_model_device = devices.get_optimal_device_name()
self.load_model_device = devices.get_optimal_device_name()
use_fp16 = devices.dtype == torch.float16
if use_fp16:
self.model.half()
self.model.to(load_model_device) # Ensure model is on the correct device
self.model.to(self.load_model_device) # Ensure the model is on the correct device
print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')
print(f'Fooocus Expansion engine loaded for {self.load_model_device}, use_fp16 = {use_fp16}.')
def unload_model(self):
"""Unload the model to free up memory."""
@@ -106,10 +90,10 @@ class FooocusExpansion:
@torch.inference_mode()
def logits_processor(self, input_ids, scores):
assert scores.ndim == 2 and scores.shape[0] == 1
self.logits_bias = self.logits_bias.to(load_model_device)
self.logits_bias = self.logits_bias.to(self.load_model_device)
bias = self.logits_bias.clone().to(load_model_device) # Ensure bias is on the correct device
bias[0, input_ids[0].to(load_model_device).long()] = neg_inf # Ensure input_ids are on the correct device
bias = self.logits_bias.clone().to(self.load_model_device) # Ensure bias is on the correct device
bias[0, input_ids[0].to(self.load_model_device).long()] = neg_inf # Ensure input_ids are on the correct device
bias[0, 11] = 0
return scores + bias.to(scores.device) # Ensure bias is on the same device as scores
@@ -124,8 +108,8 @@ class FooocusExpansion:
set_seed(seed)
prompt = safe_str(prompt) + ','
tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(load_model_device)
tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(load_model_device)
tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.load_model_device)
tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.load_model_device)
current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
@@ -199,8 +183,8 @@ class FooocusPromptExpansion(scripts.Script):
return
for i, prompt in enumerate(p.all_prompts):
positivePrompt = create_positive(prompt, seed)
p.all_prompts[i] = positivePrompt
positive_prompt = create_positive(prompt, seed)
p.all_prompts[i] = positive_prompt
def save_prompt_box(self, on_component):
self.prompt_elm = on_component.component