Handle multi control inputs for control lora training

This commit is contained in:
Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions

View File

@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from einops import rearrange, repeat
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
FluxAdvancedControlPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
@@ -1244,7 +1245,7 @@ class StableDiffusion:
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
# see if it is a control lora
if self.adapter.control_lora is not None:
Pipe = FluxControlPipeline
Pipe = FluxAdvancedControlPipeline
pipeline = Pipe(
vae=self.vae,
@@ -1367,6 +1368,7 @@ class StableDiffusion:
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['control_image'] = validation_image
extra['control_image_idx'] = gen_config.ctrl_idx
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),