Added the ability to train control LoRAs. Other important bug fixes are included as well.

This commit is contained in:
Jaret Burkett
2025-03-14 18:03:00 -06:00
parent 391329dbdc
commit 3812957bc9
7 changed files with 365 additions and 19 deletions

View File

@@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
FluxControlPipeline
from toolkit.models.lumina2 import Lumina2Transformer2DModel
from toolkit.models.flex2 import Flex2Pipeline
import diffusers
@@ -155,6 +156,7 @@ class StableDiffusion:
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
self.arch = model_config.arch
self.device_state = None
@@ -1239,6 +1241,10 @@ class StableDiffusion:
Pipe = FluxPipeline
if self.is_flex2:
Pipe = Flex2Pipeline
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
# see if it is a control lora
if self.adapter.control_lora is not None:
Pipe = FluxControlPipeline
pipeline = Pipe(
vae=self.vae,
@@ -1358,6 +1364,9 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['image'] = validation_image
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['control_image'] = validation_image
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),
@@ -2136,7 +2145,8 @@ class StableDiffusion:
w=latent_model_input.shape[3] // 2,
ph=2,
pw=2,
c=latent_model_input.shape[1],
# c=latent_model_input.shape[1],
c=self.vae.config.latent_channels
)
if bypass_guidance_embedding: