refactor(scripts): optimize get_override_settings method

Refactor the get_override_settings method in the AfterDetailerScript class to improve code readability and maintainability. Removed unused parameter '_p' and added type hints for better code documentation.
This commit is contained in:
Dowon
2024-09-22 23:54:21 +09:00
parent feefeed638
commit c439128484

View File

@@ -392,7 +392,7 @@ class AfterDetailerScript(scripts.Script):
value = args.ad_scheduler value = args.ad_scheduler
return {"scheduler": value} return {"scheduler": value}
def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]: def get_override_settings(self, _p, args: ADetailerArgs) -> dict[str, Any]:
d = {} d = {}
if args.ad_use_clip_skip: if args.ad_use_clip_skip:
@@ -413,7 +413,7 @@ class AfterDetailerScript(scripts.Script):
d["sd_vae"] = args.ad_vae d["sd_vae"] = args.ad_vae
return d return d
def get_initial_noise_multiplier(self, p, args: ADetailerArgs) -> float | None: def get_initial_noise_multiplier(self, _p, args: ADetailerArgs) -> float | None:
return args.ad_noise_multiplier if args.ad_use_noise_multiplier else None return args.ad_noise_multiplier if args.ad_use_noise_multiplier else None
@staticmethod @staticmethod
@@ -495,7 +495,9 @@ class AfterDetailerScript(scripts.Script):
return new_args return new_args
def get_i2i_p(self, p, args: ADetailerArgs, image): def get_i2i_p(
self, p, args: ADetailerArgs, image: Image.Image
) -> StableDiffusionProcessingImg2Img:
seed, subseed = self.get_seed(p) seed, subseed = self.get_seed(p)
width, height = self.get_width_height(p, args) width, height = self.get_width_height(p, args)
steps = self.get_steps(p, args) steps = self.get_steps(p, args)
@@ -563,6 +565,9 @@ class AfterDetailerScript(scripts.Script):
return i2i return i2i
def save_image(self, p, image, *, condition: str, suffix: str) -> None: def save_image(self, p, image, *, condition: str, suffix: str) -> None:
if not opts.data.get(condition, False):
return
i = get_i(p) i = get_i(p)
if p.all_prompts: if p.all_prompts:
i %= len(p.all_prompts) i %= len(p.all_prompts)
@@ -571,23 +576,22 @@ class AfterDetailerScript(scripts.Script):
save_prompt = p.prompt save_prompt = p.prompt
seed, _ = self.get_seed(p) seed, _ = self.get_seed(p)
if opts.data.get(condition, False): ad_save_images_dir: str = opts.data.get("ad_save_images_dir", "")
ad_save_images_dir: str = opts.data.get("ad_save_images_dir", "")
if not ad_save_images_dir.strip(): if not ad_save_images_dir.strip():
ad_save_images_dir = p.outpath_samples ad_save_images_dir = p.outpath_samples
images.save_image( images.save_image(
image=image, image=image,
path=ad_save_images_dir, path=ad_save_images_dir,
basename="", basename="",
seed=seed, seed=seed,
prompt=save_prompt, prompt=save_prompt,
extension=opts.samples_format, extension=opts.samples_format,
info=self.infotext(p), info=self.infotext(p),
p=p, p=p,
suffix=suffix, suffix=suffix,
) )
def get_ad_model(self, name: str): def get_ad_model(self, name: str):
if name not in model_mapping: if name not in model_mapping:
@@ -671,14 +675,9 @@ class AfterDetailerScript(scripts.Script):
mask = ImageChops.invert(mask) mask = ImageChops.invert(mask)
mask = create_binary_mask(mask) mask = create_binary_mask(mask)
if is_skip_img2img(p): width, height = p.width, p.height
if hasattr(p, "init_images") and p.init_images: if is_skip_img2img(p) and hasattr(p, "init_images") and p.init_images:
width, height = p.init_images[0].size width, height = p.init_images[0].size
else:
msg = "[-] ADetailer: no init_images."
raise RuntimeError(msg)
else:
width, height = p.width, p.height
return images.resize_image(p.resize_mode, mask, width, height) return images.resize_image(p.resize_mode, mask, width, height)
@staticmethod @staticmethod