mirror of
https://github.com/power88/webui-fooocus-prompt-expansion.git
synced 2026-05-01 03:31:16 +00:00
add unload_model() to free memory
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# Fooocus GPT2 Expansion
|
# Fooocus GPT2 Expansion
|
||||||
# Algorithm created by Lvmin Zhang at 2023, Stanford
|
# Algorithm created by Lvmin Zhang at 2023, Stanford
|
||||||
# modified by PlayDystinDB and GPT-4O for stable-diffusion-webui
|
# Modified by power88 and GPT-4o for stable-diffusion-webui
|
||||||
# If used inside Fooocus, any use is permitted.
|
# If used inside Fooocus, any use is permitted.
|
||||||
# If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
|
# If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
|
||||||
# This applies to the word list, vocab, model, and algorithm.
|
# This applies to the word list, vocab, model, and algorithm.
|
||||||
@@ -155,6 +155,12 @@ class FooocusExpansion:
|
|||||||
|
|
||||||
print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')
|
print(f'Fooocus Expansion engine loaded for {load_model_device}, use_fp16 = {use_fp16}.')
|
||||||
|
|
||||||
|
def unload_model(self):
|
||||||
|
"""Unload the model to free up memory."""
|
||||||
|
del self.model
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print('Model unloaded and memory cleared.')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def logits_processor(self, input_ids, scores):
|
def logits_processor(self, input_ids, scores):
|
||||||
@@ -199,6 +205,7 @@ def createPositive(positive, seed):
|
|||||||
try:
|
try:
|
||||||
expansion = FooocusExpansion()
|
expansion = FooocusExpansion()
|
||||||
positive = expansion(positive, seed=seed)
|
positive = expansion(positive, seed=seed)
|
||||||
|
expansion.unload_model() # Unload the model after use
|
||||||
return positive
|
return positive
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred: {str(e)}")
|
print(f"An error occurred: {str(e)}")
|
||||||
@@ -208,7 +215,7 @@ class FooocusPromptExpansion(scripts.Script):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return 'Fooocus Expansion'
|
return 'Fooocus Prompt Expansion'
|
||||||
|
|
||||||
def show(self, is_img2img):
|
def show(self, is_img2img):
|
||||||
return scripts.AlwaysVisible
|
return scripts.AlwaysVisible
|
||||||
|
|||||||
Reference in New Issue
Block a user