mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Handle multi control inputs for control lora training
This commit is contained in:
@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
from einops import rearrange, repeat
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
|
||||
FluxAdvancedControlPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
@@ -1244,7 +1245,7 @@ class StableDiffusion:
|
||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||
# see if it is a control lora
|
||||
if self.adapter.control_lora is not None:
|
||||
Pipe = FluxControlPipeline
|
||||
Pipe = FluxAdvancedControlPipeline
|
||||
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
@@ -1367,6 +1368,7 @@ class StableDiffusion:
|
||||
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['control_image'] = validation_image
|
||||
extra['control_image_idx'] = gen_config.ctrl_idx
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
|
||||
Reference in New Issue
Block a user