mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 07:49:24 +00:00
Added ability to train control loras. Other important bug fixes thrown in
This commit is contained in:
@@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
|
||||
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
|
||||
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline
|
||||
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
|
||||
FluxControlPipeline
|
||||
from toolkit.models.lumina2 import Lumina2Transformer2DModel
|
||||
from toolkit.models.flex2 import Flex2Pipeline
|
||||
import diffusers
|
||||
@@ -155,6 +156,7 @@ class StableDiffusion:
|
||||
|
||||
self.model_config = model_config
|
||||
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
|
||||
self.arch = model_config.arch
|
||||
|
||||
self.device_state = None
|
||||
|
||||
@@ -1239,6 +1241,10 @@ class StableDiffusion:
|
||||
Pipe = FluxPipeline
|
||||
if self.is_flex2:
|
||||
Pipe = Flex2Pipeline
|
||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||
# see if it is a control lora
|
||||
if self.adapter.control_lora is not None:
|
||||
Pipe = FluxControlPipeline
|
||||
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
@@ -1358,6 +1364,9 @@ class StableDiffusion:
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['image'] = validation_image
|
||||
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['control_image'] = validation_image
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
@@ -2136,7 +2145,8 @@ class StableDiffusion:
|
||||
w=latent_model_input.shape[3] // 2,
|
||||
ph=2,
|
||||
pw=2,
|
||||
c=latent_model_input.shape[1],
|
||||
# c=latent_model_input.shape[1],
|
||||
c=self.vae.config.latent_channels
|
||||
)
|
||||
|
||||
if bypass_guidance_embedding:
|
||||
|
||||
Reference in New Issue
Block a user