diff --git a/scripts/postprocessing_rembg.py b/scripts/postprocessing_rembg.py index 8d64332..9e8eaee 100644 --- a/scripts/postprocessing_rembg.py +++ b/scripts/postprocessing_rembg.py @@ -2,7 +2,9 @@ from modules import scripts_postprocessing import gradio as gr from modules.ui_components import FormRow +from modules.paths_internal import models_path import rembg +import os models = [ "None", @@ -50,6 +52,9 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): if not model or model == "None": return + if "U2NET_HOME" not in os.environ: + os.environ["U2NET_HOME"] = os.path.join(models_path, "u2net") + pp.image = rembg.remove( pp.image, session=rembg.new_session(model),