Handle inpainting training for control_lora adapter

This commit is contained in:
Jaret Burkett
2025-03-24 13:17:47 -06:00
parent f10937e6da
commit 45be82d5d6
9 changed files with 257 additions and 23 deletions

View File

@@ -1246,6 +1246,8 @@ class StableDiffusion:
# see if it is a control lora
if self.adapter.control_lora is not None:
Pipe = FluxAdvancedControlPipeline
extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input
extra_args['num_controls'] = self.adapter.config.num_control_images
pipeline = Pipe(
vae=self.vae,
@@ -1257,6 +1259,7 @@ class StableDiffusion:
scheduler=noise_scheduler,
**extra_args
)
pipeline.watermark = None
elif self.is_lumina2:
pipeline = Lumina2Text2ImgPipeline(
@@ -1355,7 +1358,14 @@ class StableDiffusion:
extra = {}
validation_image = None
if self.adapter is not None and gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = Image.open(gen_config.adapter_image_path)
# if the name doesnt have .inpainting. in it, make sure it is rgb
if ".inpaint." not in gen_config.adapter_image_path:
validation_image = validation_image.convert("RGB")
else:
# make sure it has an alpha
if validation_image.mode != "RGBA":
raise ValueError("Inpainting images must have an alpha channel")
if isinstance(self.adapter, T2IAdapter):
# not sure why this is double??
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))