diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 39d269f..f74149d 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -29,24 +29,24 @@ print( f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}" ) -all_args = [ - ("ad_model", "ADetailer model"), - ("ad_prompt", "ADetailer prompt"), - ("ad_negative_prompt", "ADetailer negative prompt"), - ("ad_conf", "ADetailer conf"), - ("ad_dilate_erode", "ADetailer dilate/erode"), - ("ad_x_offset", "ADetailer x offset"), - ("ad_y_offset", "ADetailer y offset"), - ("ad_mask_blur", "ADetailer mask blur"), - ("ad_denoising_strength", "ADetailer denoising strength"), - ("ad_inpaint_full_res", "ADetailer inpaint full"), - ("ad_inpaint_full_res_padding", "ADetailer inpaint padding"), - ("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"), - ("ad_inpaint_width", "ADetailer inpaint width"), - ("ad_inpaint_height", "ADetailer inpaint height"), - ("ad_cfg_scale", "ADetailer CFG scale"), - ("ad_controlnet_model", "ADetailer ControlNet model"), - ("ad_controlnet_weight", "ADetailer ControlNet weight"), +ALL_ARGS = [ + ("ad_model", "ADetailer model", str), + ("ad_prompt", "ADetailer prompt", str), + ("ad_negative_prompt", "ADetailer negative prompt", str), + ("ad_conf", "ADetailer conf", int), + ("ad_dilate_erode", "ADetailer dilate/erode", int), + ("ad_x_offset", "ADetailer x offset", int), + ("ad_y_offset", "ADetailer y offset", int), + ("ad_mask_blur", "ADetailer mask blur", int), + ("ad_denoising_strength", "ADetailer denoising strength", float), + ("ad_inpaint_full_res", "ADetailer inpaint full", bool), + ("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int), + ("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool), + ("ad_inpaint_width", "ADetailer inpaint width", int), + ("ad_inpaint_height", "ADetailer inpaint height", int), + ("ad_cfg_scale", "ADetailer CFG scale", float), + ("ad_controlnet_model", "ADetailer ControlNet model", str), + ("ad_controlnet_weight", "ADetailer ControlNet weight", float), ] @@ -70,8 +70,9 @@ class ADetailerArgs: ad_controlnet_weight: float def __init__(self, *args): - for i, (attr, _) in enumerate(all_args): - if i == 3: # ad_conf + args = self.ensure_dtype(args) + for i, (attr, *_) in enumerate(ALL_ARGS): + if attr == "ad_conf": setattr(self, attr, args[i] / 100.0) else: setattr(self, attr, args[i]) @@ -79,10 +80,21 @@ class ADetailerArgs: def asdict(self): return self.__dict__ + def ensure_dtype(self, args): + args = list(args) + for i, (attr, _, dtype) in enumerate(ALL_ARGS): + if not isinstance(args[i], dtype): + try: + args[i] = dtype(args[i]) + except ValueError as e: + msg = f"Error converting {attr!r} to {dtype}: {e}" + raise ValueError(msg) from e + return args + class Widgets: def tolist(self): - return [getattr(self, attr) for attr, _ in all_args] + return [getattr(self, attr) for attr, *_ in ALL_ARGS] class ChangeTorchLoad: @@ -102,6 +114,7 @@ class AfterDetailerScript(scripts.Script): def __init__(self): super().__init__() self.controlnet_ext = None + self.ultralytics_device = self.get_ultralytics_device() def title(self): return AFTER_DETAILER @@ -271,7 +284,7 @@ class AfterDetailerScript(scripts.Script): interactive=controlnet_exists, ) - self.infotext_fields = [(getattr(w, attr), name) for attr, name in all_args] + self.infotext_fields = [(getattr(w, attr), name) for attr, name, *_ in ALL_ARGS] return w.tolist() @@ -283,7 +296,7 @@ class AfterDetailerScript(scripts.Script): print("[-] ADetailer: ControlNetExt init failed.", file=sys.stderr) def extra_params(self, **kwargs): - params = {name: kwargs[attr] for attr, name in all_args} + params = {name: kwargs[attr] for attr, name, *_ in ALL_ARGS} params["ADetailer conf"] = int(params["ADetailer conf"] * 100) if not params["ADetailer prompt"]: @@ -470,7 +483,7 @@ class AfterDetailerScript(scripts.Script): else: predictor = ultralytics_predict ad_model = model_mapping[args.ad_model] - kwargs["device"] = self.get_ultralytics_device() + kwargs["device"] = self.ultralytics_device with ChangeTorchLoad(): pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)