diff --git a/scripts/postprocessing_rembg.py b/scripts/postprocessing_rembg.py index c293a5c..1b9c09e 100644 --- a/scripts/postprocessing_rembg.py +++ b/scripts/postprocessing_rembg.py @@ -2,7 +2,9 @@ from modules import scripts_postprocessing, ui_components import gradio as gr from modules.ui_components import FormRow +from modules.paths_internal import models_path import rembg +import os models = [ "None", @@ -55,6 +57,9 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): if not model or model == "None": return + if "U2NET_HOME" not in os.environ: + os.environ["U2NET_HOME"] = os.path.join(models_path, "u2net") + pp.image = rembg.remove( pp.image, session=rembg.new_session(model),