mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-13 09:20:09 +00:00
feat: ensure args dtype
This commit is contained in:
@@ -29,24 +29,24 @@ print(
|
||||
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
|
||||
)
|
||||
|
||||
all_args = [
|
||||
("ad_model", "ADetailer model"),
|
||||
("ad_prompt", "ADetailer prompt"),
|
||||
("ad_negative_prompt", "ADetailer negative prompt"),
|
||||
("ad_conf", "ADetailer conf"),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode"),
|
||||
("ad_x_offset", "ADetailer x offset"),
|
||||
("ad_y_offset", "ADetailer y offset"),
|
||||
("ad_mask_blur", "ADetailer mask blur"),
|
||||
("ad_denoising_strength", "ADetailer denoising strength"),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full"),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
|
||||
("ad_inpaint_width", "ADetailer inpaint width"),
|
||||
("ad_inpaint_height", "ADetailer inpaint height"),
|
||||
("ad_cfg_scale", "ADetailer CFG scale"),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model"),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight"),
|
||||
ALL_ARGS = [
|
||||
("ad_model", "ADetailer model", str),
|
||||
("ad_prompt", "ADetailer prompt", str),
|
||||
("ad_negative_prompt", "ADetailer negative prompt", str),
|
||||
("ad_conf", "ADetailer conf", int),
|
||||
("ad_dilate_erode", "ADetailer dilate/erode", int),
|
||||
("ad_x_offset", "ADetailer x offset", int),
|
||||
("ad_y_offset", "ADetailer y offset", int),
|
||||
("ad_mask_blur", "ADetailer mask blur", int),
|
||||
("ad_denoising_strength", "ADetailer denoising strength", float),
|
||||
("ad_inpaint_full_res", "ADetailer inpaint full", bool),
|
||||
("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int),
|
||||
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool),
|
||||
("ad_inpaint_width", "ADetailer inpaint width", int),
|
||||
("ad_inpaint_height", "ADetailer inpaint height", int),
|
||||
("ad_cfg_scale", "ADetailer CFG scale", float),
|
||||
("ad_controlnet_model", "ADetailer ControlNet model", str),
|
||||
("ad_controlnet_weight", "ADetailer ControlNet weight", float),
|
||||
]
|
||||
|
||||
|
||||
@@ -70,8 +70,9 @@ class ADetailerArgs:
|
||||
ad_controlnet_weight: float
|
||||
|
||||
def __init__(self, *args):
|
||||
for i, (attr, _) in enumerate(all_args):
|
||||
if i == 3: # ad_conf
|
||||
args = self.ensure_dtype(args)
|
||||
for i, (attr, *_) in enumerate(ALL_ARGS):
|
||||
if attr == "ad_conf":
|
||||
setattr(self, attr, args[i] / 100.0)
|
||||
else:
|
||||
setattr(self, attr, args[i])
|
||||
@@ -79,10 +80,21 @@ class ADetailerArgs:
|
||||
def asdict(self):
|
||||
return self.__dict__
|
||||
|
||||
def ensure_dtype(self, args):
|
||||
args = list(args)
|
||||
for i, (attr, _, dtype) in enumerate(ALL_ARGS):
|
||||
if not isinstance(args[i], dtype):
|
||||
try:
|
||||
args[i] = dtype(args[i])
|
||||
except ValueError as e:
|
||||
msg = f"Error converting {attr!r} to {dtype}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
return args
|
||||
|
||||
|
||||
class Widgets:
|
||||
def tolist(self):
|
||||
return [getattr(self, attr) for attr, _ in all_args]
|
||||
return [getattr(self, attr) for attr, *_ in ALL_ARGS]
|
||||
|
||||
|
||||
class ChangeTorchLoad:
|
||||
@@ -102,6 +114,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.controlnet_ext = None
|
||||
self.ultralytics_device = self.get_ultralytics_device()
|
||||
|
||||
def title(self):
|
||||
return AFTER_DETAILER
|
||||
@@ -271,7 +284,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
interactive=controlnet_exists,
|
||||
)
|
||||
|
||||
self.infotext_fields = [(getattr(w, attr), name) for attr, name in all_args]
|
||||
self.infotext_fields = [(getattr(w, attr), name) for attr, name, *_ in ALL_ARGS]
|
||||
|
||||
return w.tolist()
|
||||
|
||||
@@ -283,7 +296,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
print("[-] ADetailer: ControlNetExt init failed.", file=sys.stderr)
|
||||
|
||||
def extra_params(self, **kwargs):
|
||||
params = {name: kwargs[attr] for attr, name in all_args}
|
||||
params = {name: kwargs[attr] for attr, name, *_ in ALL_ARGS}
|
||||
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
|
||||
|
||||
if not params["ADetailer prompt"]:
|
||||
@@ -470,7 +483,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
else:
|
||||
predictor = ultralytics_predict
|
||||
ad_model = model_mapping[args.ad_model]
|
||||
kwargs["device"] = self.get_ultralytics_device()
|
||||
kwargs["device"] = self.ultralytics_device
|
||||
|
||||
with ChangeTorchLoad():
|
||||
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user