mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added comparative loss when training clip encoder. Allow selecting clip layer on ip adapter. Improvements to prior prediction.
This commit is contained in:
@@ -225,52 +225,51 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||||
|
|
||||||
if self.train_config.correct_pred_norm and not is_reg:
|
target = None
|
||||||
with torch.no_grad():
|
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
|
||||||
# this only works if doing a prior pred
|
if self.train_config.correct_pred_norm and not is_reg:
|
||||||
if prior_pred is not None:
|
with torch.no_grad():
|
||||||
prior_mean = prior_pred.mean([2,3], keepdim=True)
|
# this only works if doing a prior pred
|
||||||
prior_std = prior_pred.std([2,3], keepdim=True)
|
if prior_pred is not None:
|
||||||
noise_mean = noise_pred.mean([2,3], keepdim=True)
|
prior_mean = prior_pred.mean([2,3], keepdim=True)
|
||||||
noise_std = noise_pred.std([2,3], keepdim=True)
|
prior_std = prior_pred.std([2,3], keepdim=True)
|
||||||
|
noise_mean = noise_pred.mean([2,3], keepdim=True)
|
||||||
|
noise_std = noise_pred.std([2,3], keepdim=True)
|
||||||
|
|
||||||
mean_adjust = prior_mean - noise_mean
|
mean_adjust = prior_mean - noise_mean
|
||||||
std_adjust = prior_std - noise_std
|
std_adjust = prior_std - noise_std
|
||||||
|
|
||||||
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
|
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
|
||||||
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
|
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
|
||||||
|
|
||||||
target_mean = noise_mean + mean_adjust
|
target_mean = noise_mean + mean_adjust
|
||||||
target_std = noise_std + std_adjust
|
target_std = noise_std + std_adjust
|
||||||
|
|
||||||
eps = 1e-5
|
eps = 1e-5
|
||||||
|
# match the noise to the prior
|
||||||
|
noise = (noise - noise_mean) / (noise_std + eps)
|
||||||
|
noise = noise * (target_std + eps) + target_mean
|
||||||
|
noise = noise.detach()
|
||||||
|
|
||||||
# adjust the noise target to match the current knowledge of the model
|
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
||||||
# noise_mean, noise_std = get_mean_std(noise)
|
assert not self.train_config.train_turbo
|
||||||
# match the noise to the prior
|
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||||
noise = (noise - noise_mean) / (noise_std + eps)
|
stretched_mask_multiplier = value_map(
|
||||||
noise = noise * (target_std + eps) + target_mean
|
mask_multiplier,
|
||||||
noise = noise.detach()
|
batch.file_items[0].dataset_config.mask_min_value,
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
1.0
|
||||||
|
)
|
||||||
|
|
||||||
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
|
||||||
assert not self.train_config.train_turbo
|
|
||||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
|
||||||
stretched_mask_multiplier = value_map(
|
|
||||||
mask_multiplier,
|
|
||||||
batch.file_items[0].dataset_config.mask_min_value,
|
|
||||||
1.0,
|
|
||||||
0.0,
|
|
||||||
1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
|
# target_mask_multiplier = mask_multiplier
|
||||||
|
# mask_multiplier = 1.0
|
||||||
# target_mask_multiplier = mask_multiplier
|
target = noise
|
||||||
# mask_multiplier = 1.0
|
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
||||||
target = noise
|
# set masked multiplier to 1.0 so we dont double apply it
|
||||||
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
# mask_multiplier = 1.0
|
||||||
# set masked multiplier to 1.0 so we dont double apply it
|
|
||||||
# mask_multiplier = 1.0
|
|
||||||
elif prior_pred is not None:
|
elif prior_pred is not None:
|
||||||
assert not self.train_config.train_turbo
|
assert not self.train_config.train_turbo
|
||||||
# matching adapter prediction
|
# matching adapter prediction
|
||||||
@@ -281,6 +280,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
target = noise
|
||||||
|
|
||||||
pred = noise_pred
|
pred = noise_pred
|
||||||
|
|
||||||
if self.train_config.train_turbo:
|
if self.train_config.train_turbo:
|
||||||
@@ -360,6 +362,13 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
|
# check for additional losses
|
||||||
|
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
|
||||||
|
|
||||||
|
loss = loss + self.adapter.additional_loss.mean()
|
||||||
|
self.adapter.additional_loss = None
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||||
@@ -677,6 +686,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
batch: 'DataLoaderBatchDTO',
|
batch: 'DataLoaderBatchDTO',
|
||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||||
|
conditioned_prompts=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# todo for embeddings, we need to run without trigger words
|
# todo for embeddings, we need to run without trigger words
|
||||||
@@ -980,6 +990,17 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# it will be injected into the tokenizer when called
|
# it will be injected into the tokenizer when called
|
||||||
self.adapter(conditional_clip_embeds)
|
self.adapter(conditional_clip_embeds)
|
||||||
|
|
||||||
|
# do the custom adapter after the prior prediction
|
||||||
|
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||||
|
quad_count = random.randint(1, 4)
|
||||||
|
self.adapter.train()
|
||||||
|
self.adapter.trigger_pre_te(
|
||||||
|
tensors_0_1=clip_images,
|
||||||
|
is_training=True,
|
||||||
|
has_been_preprocessed=True,
|
||||||
|
quad_count=quad_count
|
||||||
|
)
|
||||||
|
|
||||||
with self.timer('encode_prompt'):
|
with self.timer('encode_prompt'):
|
||||||
unconditional_embeds = None
|
unconditional_embeds = None
|
||||||
if grad_on_text_encoder:
|
if grad_on_text_encoder:
|
||||||
@@ -1140,6 +1161,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
unconditional_clip_embeds = unconditional_clip_embeds.detach()
|
unconditional_clip_embeds = unconditional_clip_embeds.detach()
|
||||||
|
|
||||||
with self.timer('encode_adapter'):
|
with self.timer('encode_adapter'):
|
||||||
|
self.adapter.train()
|
||||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||||
if self.train_config.do_cfg:
|
if self.train_config.do_cfg:
|
||||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
|
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
|
||||||
@@ -1170,8 +1192,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||||
do_inverted_masked_prior = True
|
do_inverted_masked_prior = True
|
||||||
|
|
||||||
|
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
|
||||||
|
|
||||||
if ((
|
if ((
|
||||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
|
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
|
||||||
with self.timer('prior predict'):
|
with self.timer('prior predict'):
|
||||||
prior_pred = self.get_prior_prediction(
|
prior_pred = self.get_prior_prediction(
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
@@ -1182,8 +1206,12 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
pred_kwargs=pred_kwargs,
|
pred_kwargs=pred_kwargs,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
unconditional_embeds=unconditional_embeds
|
unconditional_embeds=unconditional_embeds,
|
||||||
).detach()
|
conditioned_prompts=conditioned_prompts
|
||||||
|
)
|
||||||
|
if prior_pred is not None:
|
||||||
|
prior_pred = prior_pred.detach()
|
||||||
|
|
||||||
|
|
||||||
# do the custom adapter after the prior prediction
|
# do the custom adapter after the prior prediction
|
||||||
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ class NetworkConfig:
|
|||||||
|
|
||||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
|
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
|
||||||
|
|
||||||
|
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
|
||||||
|
|
||||||
class AdapterConfig:
|
class AdapterConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -169,6 +170,13 @@ class AdapterConfig:
|
|||||||
|
|
||||||
self.class_names = kwargs.get('class_names', [])
|
self.class_names = kwargs.get('class_names', [])
|
||||||
|
|
||||||
|
self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None)
|
||||||
|
if self.clip_layer is None:
|
||||||
|
if self.type.startswith('ip+'):
|
||||||
|
self.clip_layer = 'penultimate_hidden_states'
|
||||||
|
else:
|
||||||
|
self.clip_layer = 'last_hidden_state'
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|||||||
@@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
is_unconditional=False,
|
is_unconditional=False,
|
||||||
quad_count=4,
|
quad_count=4,
|
||||||
) -> PromptEmbeds:
|
) -> PromptEmbeds:
|
||||||
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
|
if self.adapter_type == 'ilora':
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
|
||||||
if is_unconditional:
|
if is_unconditional:
|
||||||
# we dont condition the negative embeds for photo maker
|
# we dont condition the negative embeds for photo maker
|
||||||
return prompt_embeds.clone()
|
return prompt_embeds.clone()
|
||||||
@@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
self.token_mask
|
self.token_mask
|
||||||
)
|
)
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
|
elif self.adapter_type == 'clip_fusion':
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
if is_training and self.config.train_image_encoder:
|
if is_training and self.config.train_image_encoder:
|
||||||
self.vision_encoder.train()
|
self.vision_encoder.train()
|
||||||
@@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
if not is_training or not self.config.train_image_encoder:
|
if not is_training or not self.config.train_image_encoder:
|
||||||
img_embeds = img_embeds.detach()
|
img_embeds = img_embeds.detach()
|
||||||
|
|
||||||
if self.adapter_type == 'ilora':
|
|
||||||
self.ilora_module.img_embeds = img_embeds
|
|
||||||
|
|
||||||
return prompt_embeds
|
prompt_embeds.text_embeds = self.clip_fusion_module(
|
||||||
else:
|
prompt_embeds.text_embeds,
|
||||||
|
img_embeds
|
||||||
prompt_embeds.text_embeds = self.clip_fusion_module(
|
)
|
||||||
prompt_embeds.text_embeds,
|
return prompt_embeds
|
||||||
img_embeds
|
|
||||||
)
|
|
||||||
return prompt_embeds
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def trigger_pre_te(
|
||||||
|
self,
|
||||||
|
tensors_0_1: torch.Tensor,
|
||||||
|
is_training=False,
|
||||||
|
has_been_preprocessed=False,
|
||||||
|
quad_count=4,
|
||||||
|
) -> PromptEmbeds:
|
||||||
|
if self.adapter_type == 'ilora':
|
||||||
|
with torch.no_grad():
|
||||||
|
# on training the clip image is created in the dataloader
|
||||||
|
if not has_been_preprocessed:
|
||||||
|
# tensors should be 0-1
|
||||||
|
if tensors_0_1.ndim == 3:
|
||||||
|
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||||
|
# training tensors are 0 - 1
|
||||||
|
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||||
|
# if images are out of this range throw error
|
||||||
|
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||||
|
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||||
|
tensors_0_1.min(), tensors_0_1.max()
|
||||||
|
))
|
||||||
|
clip_image = self.image_processor(
|
||||||
|
images=tensors_0_1,
|
||||||
|
return_tensors="pt",
|
||||||
|
do_resize=True,
|
||||||
|
do_rescale=False,
|
||||||
|
).pixel_values
|
||||||
|
else:
|
||||||
|
clip_image = tensors_0_1
|
||||||
|
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
|
||||||
|
|
||||||
|
if self.config.quad_image:
|
||||||
|
# split the 4x4 grid and stack on batch
|
||||||
|
ci1, ci2 = clip_image.chunk(2, dim=2)
|
||||||
|
ci1, ci3 = ci1.chunk(2, dim=3)
|
||||||
|
ci2, ci4 = ci2.chunk(2, dim=3)
|
||||||
|
to_cat = []
|
||||||
|
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
|
||||||
|
if i < quad_count:
|
||||||
|
to_cat.append(ci)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
clip_image = torch.cat(to_cat, dim=0).detach()
|
||||||
|
|
||||||
|
if self.adapter_type == 'ilora':
|
||||||
|
with torch.set_grad_enabled(is_training):
|
||||||
|
if is_training and self.config.train_image_encoder:
|
||||||
|
self.vision_encoder.train()
|
||||||
|
clip_image = clip_image.requires_grad_(True)
|
||||||
|
id_embeds = self.vision_encoder(
|
||||||
|
clip_image,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.vision_encoder.eval()
|
||||||
|
id_embeds = self.vision_encoder(
|
||||||
|
clip_image, output_hidden_states=True
|
||||||
|
)
|
||||||
|
|
||||||
|
img_embeds = id_embeds['last_hidden_state']
|
||||||
|
|
||||||
|
if self.config.quad_image:
|
||||||
|
# get the outputs of the quat
|
||||||
|
chunks = img_embeds.chunk(quad_count, dim=0)
|
||||||
|
chunk_sum = torch.zeros_like(chunks[0])
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_sum = chunk_sum + chunk
|
||||||
|
# get the mean of them
|
||||||
|
|
||||||
|
img_embeds = chunk_sum / quad_count
|
||||||
|
|
||||||
|
|
||||||
|
if not is_training or not self.config.train_image_encoder:
|
||||||
|
img_embeds = img_embeds.detach()
|
||||||
|
|
||||||
|
self.ilora_module.img_embeds = img_embeds
|
||||||
|
|
||||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||||
if self.config.type == 'photo_maker':
|
if self.config.type == 'photo_maker':
|
||||||
yield from self.fuse_module.parameters(recurse)
|
yield from self.fuse_module.parameters(recurse)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import sys
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
from torch.nn.modules.module import T
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||||
@@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
self.input_size = 224
|
self.input_size = 224
|
||||||
self.clip_noise_zero = True
|
self.clip_noise_zero = True
|
||||||
self.unconditional: torch.Tensor = None
|
self.unconditional: torch.Tensor = None
|
||||||
|
self.additional_loss = None
|
||||||
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
||||||
try:
|
try:
|
||||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||||
@@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
device = self.sd_ref().unet.device
|
device = self.sd_ref().unet.device
|
||||||
if self.config.type.startswith('ip+'):
|
clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)
|
||||||
clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
|
|
||||||
else:
|
|
||||||
clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
|
|
||||||
|
|
||||||
if self.config.quad_image:
|
if self.config.quad_image:
|
||||||
# get the outputs of the quat
|
# get the outputs of the quat
|
||||||
@@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
# if drop:
|
# if drop:
|
||||||
# clip_image = clip_image * 0
|
# clip_image = clip_image * 0
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
if is_training:
|
if is_training and self.config.train_image_encoder:
|
||||||
self.image_encoder.train()
|
self.image_encoder.train()
|
||||||
clip_image = clip_image.requires_grad_(True)
|
clip_image = clip_image.requires_grad_(True)
|
||||||
if self.preprocessor is not None:
|
if self.preprocessor is not None:
|
||||||
@@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module):
|
|||||||
clip_image, output_hidden_states=True
|
clip_image, output_hidden_states=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.type.startswith('ip+'):
|
if self.config.clip_layer == 'penultimate_hidden_states':
|
||||||
# they skip last layer for ip+
|
# they skip last layer for ip+
|
||||||
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
|
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
|
||||||
clip_image_embeds = clip_output.hidden_states[-2]
|
clip_image_embeds = clip_output.hidden_states[-2]
|
||||||
|
elif self.config.clip_layer == 'last_hidden_state':
|
||||||
|
clip_image_embeds = clip_output.hidden_states[-1]
|
||||||
else:
|
else:
|
||||||
clip_image_embeds = clip_output.image_embeds
|
clip_image_embeds = clip_output.image_embeds
|
||||||
|
|
||||||
if self.config.quad_image:
|
if self.config.quad_image:
|
||||||
# get the outputs of the quat
|
# get the outputs of the quat
|
||||||
chunks = clip_image_embeds.chunk(quad_count, dim=0)
|
chunks = clip_image_embeds.chunk(quad_count, dim=0)
|
||||||
|
if self.config.train_image_encoder and is_training:
|
||||||
|
# perform a loss across all chunks this will teach the vision encoder to
|
||||||
|
# identify similarities in our pairs of images and ignore things that do not make them similar
|
||||||
|
num_losses = 0
|
||||||
|
total_loss = None
|
||||||
|
for chunk in chunks:
|
||||||
|
for chunk2 in chunks:
|
||||||
|
if chunk is not chunk2:
|
||||||
|
loss = F.mse_loss(chunk, chunk2)
|
||||||
|
if total_loss is None:
|
||||||
|
total_loss = loss
|
||||||
|
else:
|
||||||
|
total_loss = total_loss + loss
|
||||||
|
num_losses += 1
|
||||||
|
if total_loss is not None:
|
||||||
|
total_loss = total_loss / num_losses
|
||||||
|
total_loss = total_loss * 1e-2
|
||||||
|
if self.additional_loss is not None:
|
||||||
|
total_loss = total_loss + self.additional_loss
|
||||||
|
self.additional_loss = total_loss
|
||||||
|
|
||||||
chunk_sum = torch.zeros_like(chunks[0])
|
chunk_sum = torch.zeros_like(chunks[0])
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
chunk_sum = chunk_sum + chunk
|
chunk_sum = chunk_sum + chunk
|
||||||
@@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
clip_image_embeds = chunk_sum / quad_count
|
clip_image_embeds = chunk_sum / quad_count
|
||||||
|
|
||||||
if not is_training:
|
if not is_training or not self.config.train_image_encoder:
|
||||||
clip_image_embeds = clip_image_embeds.detach()
|
clip_image_embeds = clip_image_embeds.detach()
|
||||||
|
|
||||||
return clip_image_embeds
|
return clip_image_embeds
|
||||||
@@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module):
|
|||||||
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def train(self: T, mode: bool = True) -> T:
|
||||||
|
if self.config.train_image_encoder:
|
||||||
|
self.image_encoder.train(mode)
|
||||||
|
if not self.config.train_only_image_encoder:
|
||||||
|
for attn_processor in self.adapter_modules:
|
||||||
|
attn_processor.train(mode)
|
||||||
|
if self.image_proj_model is not None:
|
||||||
|
self.image_proj_model.train(mode)
|
||||||
|
return super().train(mode)
|
||||||
|
|
||||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||||
if self.config.train_only_image_encoder:
|
if self.config.train_only_image_encoder:
|
||||||
yield from self.image_encoder.parameters(recurse)
|
yield from self.image_encoder.parameters(recurse)
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ class InstantLoRAMidModule(torch.nn.Module):
|
|||||||
# reshape if needed
|
# reshape if needed
|
||||||
if len(x.shape) == 3:
|
if len(x.shape) == 3:
|
||||||
scaler = scaler.unsqueeze(1)
|
scaler = scaler.unsqueeze(1)
|
||||||
x = x * scaler
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
|
|||||||
@@ -354,6 +354,7 @@ class ToolkitNetworkMixin:
|
|||||||
self.is_ssd = is_ssd
|
self.is_ssd = is_ssd
|
||||||
self.is_vega = is_vega
|
self.is_vega = is_vega
|
||||||
self.is_v2 = is_v2
|
self.is_v2 = is_v2
|
||||||
|
self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega
|
||||||
self.is_merged_in = False
|
self.is_merged_in = False
|
||||||
self.is_lorm = is_lorm
|
self.is_lorm = is_lorm
|
||||||
self.network_config: NetworkConfig = network_config
|
self.network_config: NetworkConfig = network_config
|
||||||
@@ -361,7 +362,7 @@ class ToolkitNetworkMixin:
|
|||||||
self.lorm_train_mode: Literal['local', None] = None
|
self.lorm_train_mode: Literal['local', None] = None
|
||||||
self.can_merge_in = not is_lorm
|
self.can_merge_in = not is_lorm
|
||||||
|
|
||||||
def get_keymap(self: Network):
|
def get_keymap(self: Network, force_weight_mapping=False):
|
||||||
use_weight_mapping = False
|
use_weight_mapping = False
|
||||||
|
|
||||||
if self.is_ssd:
|
if self.is_ssd:
|
||||||
@@ -377,6 +378,9 @@ class ToolkitNetworkMixin:
|
|||||||
else:
|
else:
|
||||||
keymap_tail = 'sd1'
|
keymap_tail = 'sd1'
|
||||||
# todo double check this
|
# todo double check this
|
||||||
|
# use_weight_mapping = True
|
||||||
|
|
||||||
|
if force_weight_mapping:
|
||||||
use_weight_mapping = True
|
use_weight_mapping = True
|
||||||
|
|
||||||
# load keymap
|
# load keymap
|
||||||
@@ -440,9 +444,9 @@ class ToolkitNetworkMixin:
|
|||||||
else:
|
else:
|
||||||
torch.save(save_dict, file)
|
torch.save(save_dict, file)
|
||||||
|
|
||||||
def load_weights(self: Network, file):
|
def load_weights(self: Network, file, force_weight_mapping=False):
|
||||||
# allows us to save and load to and from ldm weights
|
# allows us to save and load to and from ldm weights
|
||||||
keymap = self.get_keymap()
|
keymap = self.get_keymap(force_weight_mapping)
|
||||||
keymap = {} if keymap is None else keymap
|
keymap = {} if keymap is None else keymap
|
||||||
|
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
@@ -468,6 +472,11 @@ class ToolkitNetworkMixin:
|
|||||||
for key in to_delete:
|
for key in to_delete:
|
||||||
del load_sd[key]
|
del load_sd[key]
|
||||||
|
|
||||||
|
print(f"Missing keys: {to_delete}")
|
||||||
|
if len(to_delete) > 0 and self.is_v1:
|
||||||
|
print(" Attempting to load with forced keymap")
|
||||||
|
return self.load_weights(file, force_weight_mapping=True)
|
||||||
|
|
||||||
info = self.load_state_dict(load_sd, False)
|
info = self.load_state_dict(load_sd, False)
|
||||||
if len(extra_dict.keys()) == 0:
|
if len(extra_dict.keys()) == 0:
|
||||||
extra_dict = None
|
extra_dict = None
|
||||||
|
|||||||
@@ -528,6 +528,14 @@ class StableDiffusion:
|
|||||||
)
|
)
|
||||||
gen_config.negative_prompt_2 = gen_config.negative_prompt
|
gen_config.negative_prompt_2 = gen_config.negative_prompt
|
||||||
|
|
||||||
|
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||||
|
self.adapter.trigger_pre_te(
|
||||||
|
tensors_0_1=validation_image,
|
||||||
|
is_training=False,
|
||||||
|
has_been_preprocessed=False,
|
||||||
|
quad_count=4
|
||||||
|
)
|
||||||
|
|
||||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user