mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-24 08:19:24 +00:00
Handle inpainting training for control_lora adapter
This commit is contained in:
@@ -1246,6 +1246,8 @@ class StableDiffusion:
|
||||
# see if it is a control lora
|
||||
if self.adapter.control_lora is not None:
|
||||
Pipe = FluxAdvancedControlPipeline
|
||||
extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input
|
||||
extra_args['num_controls'] = self.adapter.config.num_control_images
|
||||
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
@@ -1257,6 +1259,7 @@ class StableDiffusion:
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
)
|
||||
|
||||
pipeline.watermark = None
|
||||
elif self.is_lumina2:
|
||||
pipeline = Lumina2Text2ImgPipeline(
|
||||
@@ -1355,7 +1358,14 @@ class StableDiffusion:
|
||||
extra = {}
|
||||
validation_image = None
|
||||
if self.adapter is not None and gen_config.adapter_image_path is not None:
|
||||
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
|
||||
validation_image = Image.open(gen_config.adapter_image_path)
|
||||
# if the name doesnt have .inpainting. in it, make sure it is rgb
|
||||
if ".inpaint." not in gen_config.adapter_image_path:
|
||||
validation_image = validation_image.convert("RGB")
|
||||
else:
|
||||
# make sure it has an alpha
|
||||
if validation_image.mode != "RGBA":
|
||||
raise ValueError("Inpainting images must have an alpha channel")
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
# not sure why this is double??
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
|
||||
Reference in New Issue
Block a user