feat: ensure args dtype

This commit is contained in:
Bingsu
2023-04-30 14:41:26 +09:00
parent ce74bc839e
commit 82377eb916

View File

@@ -29,24 +29,24 @@ print(
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
)
all_args = [
("ad_model", "ADetailer model"),
("ad_prompt", "ADetailer prompt"),
("ad_negative_prompt", "ADetailer negative prompt"),
("ad_conf", "ADetailer conf"),
("ad_dilate_erode", "ADetailer dilate/erode"),
("ad_x_offset", "ADetailer x offset"),
("ad_y_offset", "ADetailer y offset"),
("ad_mask_blur", "ADetailer mask blur"),
("ad_denoising_strength", "ADetailer denoising strength"),
("ad_inpaint_full_res", "ADetailer inpaint full"),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding"),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height"),
("ad_inpaint_width", "ADetailer inpaint width"),
("ad_inpaint_height", "ADetailer inpaint height"),
("ad_cfg_scale", "ADetailer CFG scale"),
("ad_controlnet_model", "ADetailer ControlNet model"),
("ad_controlnet_weight", "ADetailer ControlNet weight"),
ALL_ARGS = [
("ad_model", "ADetailer model", str),
("ad_prompt", "ADetailer prompt", str),
("ad_negative_prompt", "ADetailer negative prompt", str),
("ad_conf", "ADetailer conf", int),
("ad_dilate_erode", "ADetailer dilate/erode", int),
("ad_x_offset", "ADetailer x offset", int),
("ad_y_offset", "ADetailer y offset", int),
("ad_mask_blur", "ADetailer mask blur", int),
("ad_denoising_strength", "ADetailer denoising strength", float),
("ad_inpaint_full_res", "ADetailer inpaint full", bool),
("ad_inpaint_full_res_padding", "ADetailer inpaint padding", int),
("ad_use_inpaint_width_height", "ADetailer use inpaint width/height", bool),
("ad_inpaint_width", "ADetailer inpaint width", int),
("ad_inpaint_height", "ADetailer inpaint height", int),
("ad_cfg_scale", "ADetailer CFG scale", float),
("ad_controlnet_model", "ADetailer ControlNet model", str),
("ad_controlnet_weight", "ADetailer ControlNet weight", float),
]
@@ -70,8 +70,9 @@ class ADetailerArgs:
ad_controlnet_weight: float
def __init__(self, *args):
for i, (attr, _) in enumerate(all_args):
if i == 3: # ad_conf
args = self.ensure_dtype(args)
for i, (attr, *_) in enumerate(ALL_ARGS):
if attr == "ad_conf":
setattr(self, attr, args[i] / 100.0)
else:
setattr(self, attr, args[i])
@@ -79,10 +80,21 @@ class ADetailerArgs:
def asdict(self):
return self.__dict__
def ensure_dtype(self, args):
args = list(args)
for i, (attr, _, dtype) in enumerate(ALL_ARGS):
if not isinstance(args[i], dtype):
try:
args[i] = dtype(args[i])
except ValueError as e:
msg = f"Error converting {attr!r} to {dtype}: {e}"
raise ValueError(msg) from e
return args
class Widgets:
def tolist(self):
return [getattr(self, attr) for attr, _ in all_args]
return [getattr(self, attr) for attr, *_ in ALL_ARGS]
class ChangeTorchLoad:
@@ -102,6 +114,7 @@ class AfterDetailerScript(scripts.Script):
def __init__(self):
super().__init__()
self.controlnet_ext = None
self.ultralytics_device = self.get_ultralytics_device()
def title(self):
return AFTER_DETAILER
@@ -271,7 +284,7 @@ class AfterDetailerScript(scripts.Script):
interactive=controlnet_exists,
)
self.infotext_fields = [(getattr(w, attr), name) for attr, name in all_args]
self.infotext_fields = [(getattr(w, attr), name) for attr, name, *_ in ALL_ARGS]
return w.tolist()
@@ -283,7 +296,7 @@ class AfterDetailerScript(scripts.Script):
print("[-] ADetailer: ControlNetExt init failed.", file=sys.stderr)
def extra_params(self, **kwargs):
params = {name: kwargs[attr] for attr, name in all_args}
params = {name: kwargs[attr] for attr, name, *_ in ALL_ARGS}
params["ADetailer conf"] = int(params["ADetailer conf"] * 100)
if not params["ADetailer prompt"]:
@@ -470,7 +483,7 @@ class AfterDetailerScript(scripts.Script):
else:
predictor = ultralytics_predict
ad_model = model_mapping[args.ad_model]
kwargs["device"] = self.get_ultralytics_device()
kwargs["device"] = self.ultralytics_device
with ChangeTorchLoad():
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)