Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added ability to train control loras. Other important bug fixes thrown in
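The change concatenates a VAE-encoded control image onto the noisy latents along the channel dimension and swaps the transformer's patch embedder (x_embedder on Flux) for a full-rank LoRA that accepts the doubled channel count. A minimal sketch of that conditioning step, with assumed tensor shapes for a Flux-style model (illustration only, not code from this commit):

# illustrative only; shape values are assumptions
import torch

bs, ch, h, w = 2, 16, 64, 64                                       # assumed VAE latent shape
noisy_latents = torch.randn(bs, ch, h, w)                          # latents being denoised
control_latent = torch.randn(bs, ch, h, w)                         # VAE-encoded control image
model_input = torch.cat((noisy_latents, control_latent), dim=1)    # (bs, 2*ch, h, w)
# the patch embedder must now accept 2*ch input channels; see ImgEmbedder in the new file below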
@@ -106,7 +106,7 @@ class LoRMConfig:
    })


NetworkType = Literal['lora', 'locon', 'lorm']
NetworkType = Literal['lora', 'locon', 'lorm', 'lokr']


class NetworkConfig:
@@ -151,7 +151,7 @@ class NetworkConfig:
        self.lokr_factor = kwargs.get('lokr_factor', -1)


AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora']

CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']

@@ -234,6 +234,13 @@ class AdapterConfig:
        # for llm adapter
        self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
        self.quantize_llm: bool = kwargs.get('quantize_llm', False)

        # for control lora only
        lora_config: dict = kwargs.get('lora_config', None)
        if lora_config is not None:
            self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
        else:
            self.lora_config = None


class EmbeddingConfig:
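# Hedged example of the new nested config: an adapter of type 'control_lora' carries its own
# NetworkConfig under `lora_config`. Keys and values below are illustrative assumptions inferred
# from the NetworkConfig fields used by ControlLoraAdapter later in this diff, not documented defaults.
adapter_kwargs = {
    'type': 'control_lora',
    'lora_config': {            # parsed via NetworkConfig(**lora_config)
        'type': 'lora',
        'linear': 64,           # read back as network_config.linear (the LoRA rank)
        'linear_alpha': 64,
    },
}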
@@ -521,6 +528,32 @@ class ModelConfig:
        self.arch: ModelArch = kwargs.get("arch", None)

        # handle migrating to new model arch
        if self.arch is not None:
            # reverse the arch to the old style
            if self.arch == 'sd2':
                self.is_v2 = True
            elif self.arch == 'sd3':
                self.is_v3 = True
            elif self.arch == 'sdxl':
                self.is_xl = True
            elif self.arch == 'pixart':
                self.is_pixart = True
            elif self.arch == 'pixart_sigma':
                self.is_pixart_sigma = True
            elif self.arch == 'auraflow':
                self.is_auraflow = True
            elif self.arch == 'flux':
                self.is_flux = True
            elif self.arch == 'flex2':
                self.is_flex2 = True
            elif self.arch == 'lumina2':
                self.is_lumina2 = True
            elif self.arch == 'vega':
                self.is_vega = True
            elif self.arch == 'ssd':
                self.is_ssd = True
            else:
                pass
        if self.arch is None:
            if kwargs.get('is_v2', False):
                self.arch = 'sd2'
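# Hedged illustration of the arch migration above (assumes ModelConfig accepts plain kwargs):
# passing the new-style `arch` key still populates the legacy boolean flags,
# and legacy flags backfill `arch` when it is not given.
cfg = ModelConfig(arch='flux')
assert cfg.is_flux is True
cfg = ModelConfig(is_v2=True)
assert cfg.arch == 'sd2'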
@@ -7,8 +7,10 @@ from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, CLIPTextModel, \
    CLIPTokenizer, T5Tokenizer

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.single_value_adapter import SingleValueAdapter
from toolkit.models.te_adapter import TEAdapter
@@ -29,7 +31,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProces
    AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig, AdapterTypes
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

@@ -58,10 +60,11 @@ import torch.nn.functional as F


class CustomAdapter(torch.nn.Module):
    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig', train_config: 'TrainConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.train_config = train_config
        self.device = self.sd_ref().unet.device
        self.image_processor: CLIPImageProcessor = None
        self.input_size = 224
@@ -97,6 +100,7 @@ class CustomAdapter(torch.nn.Module):
        self.vd_adapter: VisionDirectAdapter = None
        self.single_value_adapter: SingleValueAdapter = None
        self.redux_adapter: ReduxImageEncoder = None
        self.control_lora: ControlLoraAdapter = None

        self.conditional_embeds: Optional[torch.Tensor] = None
        self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -240,6 +244,13 @@ class CustomAdapter(torch.nn.Module):
        elif self.adapter_type == 'redux':
            vision_hidden_size = self.vision_encoder.config.hidden_size
            self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
        elif self.adapter_type == 'control_lora':
            self.control_lora = ControlLoraAdapter(
                self,
                sd=self.sd_ref(),
                config=self.config,
                train_config=self.train_config
            )
        else:
            raise ValueError(f"unknown adapter type: {self.adapter_type}")

@@ -271,7 +282,7 @@ class CustomAdapter(torch.nn.Module):
    def setup_clip(self):
        adapter_config = self.config
        sd = self.sd_ref()
        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
            return
        if self.config.type == 'photo_maker':
            try:
@@ -481,6 +492,14 @@ class CustomAdapter(torch.nn.Module):
                for k2, v2 in v.items():
                    new_dict[k + '.' + k2] = v2
            self.redux_adapter.load_state_dict(new_dict, strict=True)

        if self.adapter_type == 'control_lora':
            # state dict is separated, so recombine it
            new_dict = {}
            for k, v in state_dict.items():
                for k2, v2 in v.items():
                    new_dict[k + '.' + k2] = v2
            self.control_lora.load_weights(new_dict, strict=strict)

        pass

@@ -532,6 +551,11 @@ class CustomAdapter(torch.nn.Module):
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        elif self.adapter_type == 'control_lora':
            d = self.control_lora.get_state_dict()
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        else:
            raise NotImplementedError

@@ -541,6 +565,33 @@ class CustomAdapter(torch.nn.Module):
            self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
        else:
            self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))

    def condition_noisy_latents(self, latents: torch.Tensor, batch: DataLoaderBatchDTO):
        with torch.no_grad():
            if self.adapter_type in ['control_lora']:
                sd: StableDiffusion = self.sd_ref()
                control_tensor = batch.control_tensor
                if control_tensor is None:
                    # concat random normal noise onto the latents
                    # check dimension, this is before they are rearranged
                    # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
                    latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
                    return latents.detach()
                # it is 0-1, need to convert to -1 to 1
                control_tensor = control_tensor * 2 - 1

                control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)

                # if it is not the size of batch.tensor (bs, ch, h, w), then we need to resize it
                if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
                    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')

                # encode it
                control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
                # concat it onto the latents
                latents = torch.cat((latents, control_latent), dim=1)
                return latents.detach()
        return latents

    def condition_prompt(
@@ -548,7 +599,7 @@ class CustomAdapter(torch.nn.Module):
        prompt: Union[List[str], str],
        is_unconditional: bool = False,
    ):
        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
            return prompt
        elif self.adapter_type == 'text_encoder':
            # todo allow for training
@@ -1067,6 +1118,10 @@ class CustomAdapter(torch.nn.Module):
            yield from self.single_value_adapter.parameters(recurse)
        elif self.config.type == 'redux':
            yield from self.redux_adapter.parameters(recurse)
        elif self.config.type == 'control_lora':
            param_list = self.control_lora.get_params()
            for param in param_list:
                yield param
        else:
            raise NotImplementedError

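# Hedged usage sketch (caller names assumed): for a control LoRA the trainer only needs to widen
# the latents through the adapter before they reach the transformer.
noisy_latents = adapter.condition_noisy_latents(noisy_latents, batch)   # (bs, 2*ch, h, w) when adapter_type == 'control_lora'
# batch.control_tensor is expected in 0..1; condition_noisy_latents rescales it to -1..1 and
# VAE-encodes it itself, or concatenates random noise when no control image is present.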
toolkit/models/control_lora_adapter.py (new file, 239 lines)
@@ -0,0 +1,239 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
# weakref


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter


# after each step we concat the control image with the latents
# latent_model_input = torch.cat([latents, control_image], dim=2)
# the x_embedder has a full rank lora to handle the additional channels
# this replaces the x_embedder with a full rank lora. on flux this is
# x_embedder (diffusers) or img_in (bfl)

# Flux
# img_in.lora_A.weight [128, 128]
# img_in.lora_B.bias [3072]
# img_in.lora_B.weight [3072, 128]

class ImgEmbedder(torch.nn.Module):
    def __init__(
        self,
        adapter: 'ControlLoraAdapter',
        orig_layer: torch.nn.Module,
        in_channels=128,
        out_channels=3072,
        bias=True
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
        self.lora_A = torch.nn.Linear(in_channels, in_channels, bias=False)  # lora down
        self.lora_B = torch.nn.Linear(in_channels, out_channels, bias=bias)  # lora up

    @classmethod
    def from_model(
        cls,
        model: FluxTransformer2DModel,
        adapter: 'ControlLoraAdapter',
        num_channel_multiplier=2
    ):
        if model.__class__.__name__ == 'FluxTransformer2DModel':
            x_embedder: torch.nn.Linear = model.x_embedder
            img_embedder = cls(
                adapter,
                orig_layer=x_embedder,
                in_channels=x_embedder.in_features * num_channel_multiplier,  # adding additional control img channels
                out_channels=x_embedder.out_features,
                bias=x_embedder.bias is not None
            )

            # hijack the forward method
            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
            x_embedder.forward = img_embedder.forward
            dtype = x_embedder.weight.dtype
            device = x_embedder.weight.device

            # since we are adding control channels, we want those channels to be zero starting out
            # so they have no effect. lora_B copies the original weight and bias, and we concat 0s for the new input channels
            # lora_A needs to be the identity so that, at init, lora_B reproduces the original x_embedder output
            img_embedder.lora_A.weight.data = torch.eye(x_embedder.in_features * num_channel_multiplier).to(dtype=torch.float32, device=device)
            weight_b = x_embedder.weight.data.clone().to(dtype=torch.float32, device=device)
            # concat 0s for the new channels
            weight_b = torch.cat([weight_b, torch.zeros(weight_b.shape[0], weight_b.shape[1] * (num_channel_multiplier - 1)).to(device)], dim=1)
            img_embedder.lora_B.weight.data = weight_b.clone().to(dtype=torch.float32)
            img_embedder.lora_B.bias.data = x_embedder.bias.data.clone().to(dtype=torch.float32)

            # update the config of the transformer
            model.config.in_channels = model.config.in_channels * num_channel_multiplier
            model.config["in_channels"] = model.config.in_channels

            return img_embedder
        else:
            raise ValueError("Model not supported")

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def forward(self, x):
        if not self.is_active:
            # make sure lora is not active
            self.adapter_ref().control_lora.is_active = False
            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

        # make sure lora is active
        self.adapter_ref().control_lora.is_active = True

        orig_device = x.device
        orig_dtype = x.dtype
        x = x.to(self.lora_A.weight.device, dtype=self.lora_A.weight.dtype)

        x = self.lora_A(x)
        x = self.lora_B(x)
        x = x.to(orig_device, dtype=orig_dtype)
        return x


class ControlLoraAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        config: 'AdapterConfig',
        train_config: 'TrainConfig'
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        if self.network_config is None:
            raise ValueError("LoRA config is missing")

        network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
        if hasattr(sd, 'target_lora_modules'):
            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules

        if 'ignore_if_contains' not in network_kwargs:
            network_kwargs['ignore_if_contains'] = []

        # always ignore x_embedder
        network_kwargs['ignore_if_contains'].append('x_embedder')

        self.device_torch = sd.device_torch
        self.control_lora = LoRASpecialNetwork(
            text_encoder=sd.text_encoder,
            unet=sd.unet,
            lora_dim=self.network_config.linear,
            multiplier=1.0,
            alpha=self.network_config.linear_alpha,
            train_unet=self.train_config.train_unet,
            train_text_encoder=self.train_config.train_text_encoder,
            conv_lora_dim=self.network_config.conv,
            conv_alpha=self.network_config.conv_alpha,
            is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
            is_v2=self.model_config.is_v2,
            is_v3=self.model_config.is_v3,
            is_pixart=self.model_config.is_pixart,
            is_auraflow=self.model_config.is_auraflow,
            is_flux=self.model_config.is_flux,
            is_lumina2=self.model_config.is_lumina2,
            is_ssd=self.model_config.is_ssd,
            is_vega=self.model_config.is_vega,
            dropout=self.network_config.dropout,
            use_text_encoder_1=self.model_config.use_text_encoder_1,
            use_text_encoder_2=self.model_config.use_text_encoder_2,
            use_bias=False,
            is_lorm=False,
            network_config=self.network_config,
            network_type=self.network_config.type,
            transformer_only=self.network_config.transformer_only,
            is_transformer=sd.is_transformer,
            base_model=sd,
            **network_kwargs
        )
        self.control_lora.force_to(self.device_torch, dtype=torch.float32)
        self.control_lora._update_torch_multiplier()
        self.control_lora.apply_to(
            sd.text_encoder,
            sd.unet,
            self.train_config.train_text_encoder,
            self.train_config.train_unet
        )
        self.control_lora.can_merge_in = False
        self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
        if self.train_config.gradient_checkpointing:
            self.control_lora.enable_gradient_checkpointing()

        self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
        self.x_embedder.to(self.device_torch)
    def get_params(self):
        # LyCORIS doesn't have default_lr
        config = {
            'text_encoder_lr': self.train_config.lr,
            'unet_lr': self.train_config.lr,
        }
        sig = inspect.signature(self.control_lora.prepare_optimizer_params)
        if 'default_lr' in sig.parameters:
            config['default_lr'] = self.train_config.lr
        if 'learning_rate' in sig.parameters:
            config['learning_rate'] = self.train_config.lr
        params_net = self.control_lora.prepare_optimizer_params(
            **config
        )

        # we want only tensors here
        params = []
        for p in params_net:
            if isinstance(p, dict):
                params += p["params"]
            elif isinstance(p, torch.Tensor):
                params.append(p)
            elif isinstance(p, list):
                params += p

        params += list(self.x_embedder.parameters())

        # we need to be able to yield from the list, like `yield from params`
        return params
    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        img_embedder_sd = {}
        for key, value in state_dict.items():
            if "x_embedder" in key:
                new_key = key.replace("transformer.x_embedder.", "")
                img_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading
        self.control_lora.load_weights(lora_sd)
        self.x_embedder.load_state_dict(img_embedder_sd, strict=strict)
    def get_state_dict(self):
        lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
        # todo: make sure we match loras elsewhere.
        img_embedder_sd = self.x_embedder.state_dict()
        for key, value in img_embedder_sd.items():
            lora_sd[f"transformer.x_embedder.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
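# A small standalone check of the ImgEmbedder init scheme above (a sketch, assuming a plain
# nn.Linear in place of the real x_embedder): lora_A is the identity over the widened input and
# lora_B is [W_orig | 0] with the original bias, so at init the widened embedder reproduces the
# original output and the control channels contribute nothing.
import torch

in_f, out_f, mult = 64, 128, 2
orig = torch.nn.Linear(in_f, out_f)

lora_A = torch.nn.Linear(in_f * mult, in_f * mult, bias=False)
lora_B = torch.nn.Linear(in_f * mult, out_f, bias=True)
lora_A.weight.data = torch.eye(in_f * mult)
lora_B.weight.data = torch.cat([orig.weight.data, torch.zeros(out_f, in_f * (mult - 1))], dim=1)
lora_B.bias.data = orig.bias.data.clone()

x = torch.randn(3, in_f)
control = torch.randn(3, in_f)
widened = torch.cat([x, control], dim=1)
assert torch.allclose(lora_B(lora_A(widened)), orig(x), atol=1e-5)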
@@ -491,13 +491,8 @@ class ToolkitNetworkMixin:
        keymap = new_keymap

        return keymap

    def save_weights(
        self: Network,
        file, dtype=torch.float16,
        metadata=None,
        extra_state_dict: Optional[OrderedDict] = None
    ):

    def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16):
        keymap = self.get_keymap()

        save_keymap = {}
@@ -506,9 +501,6 @@ class ToolkitNetworkMixin:
                # invert them
                save_keymap[diffusers_key] = ldm_key

        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()
        save_dict = OrderedDict()

@@ -556,10 +548,22 @@ class ToolkitNetworkMixin:
            save_dict = new_save_dict

        save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)
        return save_dict

    def save_weights(
        self: Network,
        file, dtype=torch.float16,
        metadata=None,
        extra_state_dict: Optional[OrderedDict] = None
    ):
        save_dict = self.get_state_dict(extra_state_dict=extra_state_dict, dtype=dtype)

        if metadata is not None and len(metadata) == 0:
            metadata = None

        if metadata is None:
            metadata = OrderedDict()
        metadata = add_model_hash_to_meta(state_dict, metadata)
        metadata = add_model_hash_to_meta(save_dict, metadata)
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            save_file(save_dict, file, metadata)

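# Hedged usage note on the refactor above: state-dict assembly now lives in get_state_dict, so
# callers (e.g. ControlLoraAdapter.get_state_dict) can grab the weights in memory, while
# save_weights stays a thin wrapper that adds metadata and writes the file. Names assumed.
weights = network.get_state_dict(dtype=torch.float32)                 # in-memory state dict
network.save_weights('my_lora.safetensors', dtype=torch.float16)      # writes via the same path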
@@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
    FluxControlPipeline
from toolkit.models.lumina2 import Lumina2Transformer2DModel
from toolkit.models.flex2 import Flex2Pipeline
import diffusers
@@ -155,6 +156,7 @@ class StableDiffusion:

        self.model_config = model_config
        self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
        self.arch = model_config.arch

        self.device_state = None

@@ -1239,6 +1241,10 @@ class StableDiffusion:
            Pipe = FluxPipeline
            if self.is_flex2:
                Pipe = Flex2Pipeline
            if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                # see if it is a control lora
                if self.adapter.control_lora is not None:
                    Pipe = FluxControlPipeline

            pipeline = Pipe(
                vae=self.vae,
@@ -1358,6 +1364,9 @@ class StableDiffusion:
                    validation_image = validation_image.resize((gen_config.width, gen_config.height))
                    extra['image'] = validation_image
                    extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
                if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
                    validation_image = validation_image.resize((gen_config.width, gen_config.height))
                    extra['control_image'] = validation_image
                if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                    transform = transforms.Compose([
                        transforms.ToTensor(),
@@ -2136,7 +2145,8 @@ class StableDiffusion:
                w=latent_model_input.shape[3] // 2,
                ph=2,
                pw=2,
                c=latent_model_input.shape[1],
                # c=latent_model_input.shape[1],
                c=self.vae.config.latent_channels
            )

            if bypass_guidance_embedding:
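# Hedged note on the change above: once the control latent is concatenated, the model input's
# channel dim is twice the VAE latent channels, so the rearrange must take `c` from
# vae.config.latent_channels rather than latent_model_input.shape[1]. Numbers assumed:
latent_channels = 16                         # e.g. vae.config.latent_channels on Flux-style VAEs
packed_input_channels = 2 * latent_channels  # after torch.cat([latents, control_image], dim=1)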