Added comparitive loss when training clip encoder. Allow selecting clip layer. on ip adapter. Improvements to prior prediction

This commit is contained in:
Jaret Burkett
2024-02-05 07:40:03 -07:00
parent 177c7130ec
commit e18e0cb5f8
7 changed files with 227 additions and 65 deletions

View File

@@ -225,52 +225,51 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.correct_pred_norm and not is_reg:
with torch.no_grad():
# this only works if doing a prior pred
if prior_pred is not None:
prior_mean = prior_pred.mean([2,3], keepdim=True)
prior_std = prior_pred.std([2,3], keepdim=True)
noise_mean = noise_pred.mean([2,3], keepdim=True)
noise_std = noise_pred.std([2,3], keepdim=True)
target = None
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
if self.train_config.correct_pred_norm and not is_reg:
with torch.no_grad():
# this only works if doing a prior pred
if prior_pred is not None:
prior_mean = prior_pred.mean([2,3], keepdim=True)
prior_std = prior_pred.std([2,3], keepdim=True)
noise_mean = noise_pred.mean([2,3], keepdim=True)
noise_std = noise_pred.std([2,3], keepdim=True)
mean_adjust = prior_mean - noise_mean
std_adjust = prior_std - noise_std
mean_adjust = prior_mean - noise_mean
std_adjust = prior_std - noise_std
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
target_mean = noise_mean + mean_adjust
target_std = noise_std + std_adjust
target_mean = noise_mean + mean_adjust
target_std = noise_std + std_adjust
eps = 1e-5
eps = 1e-5
# match the noise to the prior
noise = (noise - noise_mean) / (noise_std + eps)
noise = noise * (target_std + eps) + target_mean
noise = noise.detach()
# adjust the noise target to match the current knowledge of the model
# noise_mean, noise_std = get_mean_std(noise)
# match the noise to the prior
noise = (noise - noise_mean) / (noise_std + eps)
noise = noise * (target_std + eps) + target_mean
noise = noise.detach()
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
batch.file_items[0].dataset_config.mask_min_value,
1.0,
0.0,
1.0
)
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
batch.file_items[0].dataset_config.mask_min_value,
1.0,
0.0,
1.0
)
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0
target = noise
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
# set masked multiplier to 1.0 so we dont double apply it
# mask_multiplier = 1.0
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0
target = noise
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
# set masked multiplier to 1.0 so we dont double apply it
# mask_multiplier = 1.0
elif prior_pred is not None:
assert not self.train_config.train_turbo
# matching adapter prediction
@@ -281,6 +280,9 @@ class SDTrainer(BaseSDTrainProcess):
else:
target = noise
if target is None:
target = noise
pred = noise_pred
if self.train_config.train_turbo:
@@ -360,6 +362,13 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# check for additional losses
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
loss = loss + self.adapter.additional_loss.mean()
self.adapter.additional_loss = None
return loss
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
@@ -677,6 +686,7 @@ class SDTrainer(BaseSDTrainProcess):
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
conditioned_prompts=None,
**kwargs
):
# todo for embeddings, we need to run without trigger words
@@ -980,6 +990,17 @@ class SDTrainer(BaseSDTrainProcess):
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
self.adapter.train()
self.adapter.trigger_pre_te(
tensors_0_1=clip_images,
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
)
with self.timer('encode_prompt'):
unconditional_embeds = None
if grad_on_text_encoder:
@@ -1140,6 +1161,7 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_clip_embeds = unconditional_clip_embeds.detach()
with self.timer('encode_adapter'):
self.adapter.train()
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
@@ -1170,8 +1192,10 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
do_inverted_masked_prior = True
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
if ((
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
@@ -1182,8 +1206,12 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
unconditional_embeds=unconditional_embeds
).detach()
unconditional_embeds=unconditional_embeds,
conditioned_prompts=conditioned_prompts
)
if prior_pred is not None:
prior_pred = prior_pred.detach()
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:

View File

@@ -130,6 +130,7 @@ class NetworkConfig:
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
class AdapterConfig:
def __init__(self, **kwargs):
@@ -169,6 +170,13 @@ class AdapterConfig:
self.class_names = kwargs.get('class_names', [])
self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None)
if self.clip_layer is None:
if self.type.startswith('ip+'):
self.clip_layer = 'penultimate_hidden_states'
else:
self.clip_layer = 'last_hidden_state'
class EmbeddingConfig:
def __init__(self, **kwargs):

View File

@@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module):
is_unconditional=False,
quad_count=4,
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
if self.adapter_type == 'ilora':
return prompt_embeds
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
if is_unconditional:
# we dont condition the negative embeds for photo maker
return prompt_embeds.clone()
@@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module):
self.token_mask
)
return prompt_embeds
elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
elif self.adapter_type == 'clip_fusion':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
@@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module):
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
if self.adapter_type == 'ilora':
self.ilora_module.img_embeds = img_embeds
return prompt_embeds
else:
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
else:
raise NotImplementedError
def trigger_pre_te(
self,
tensors_0_1: torch.Tensor,
is_training=False,
has_been_preprocessed=False,
quad_count=4,
) -> PromptEmbeds:
if self.adapter_type == 'ilora':
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if self.config.quad_image:
# split the 4x4 grid and stack on batch
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
to_cat = []
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
if i < quad_count:
to_cat.append(ci)
else:
break
clip_image = torch.cat(to_cat, dim=0).detach()
if self.adapter_type == 'ilora':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
output_hidden_states=True,
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, output_hidden_states=True
)
img_embeds = id_embeds['last_hidden_state']
if self.config.quad_image:
# get the outputs of the quat
chunks = img_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
# get the mean of them
img_embeds = chunk_sum / quad_count
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
self.ilora_module.img_embeds = img_embeds
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.type == 'photo_maker':
yield from self.fuse_module.parameters(recurse)

View File

@@ -5,6 +5,7 @@ import sys
from PIL import Image
from torch.nn import Parameter
from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
@@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module):
self.input_size = 224
self.clip_noise_zero = True
self.unconditional: torch.Tensor = None
self.additional_loss = None
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module):
):
with torch.no_grad():
device = self.sd_ref().unet.device
if self.config.type.startswith('ip+'):
clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
else:
clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)
if self.config.quad_image:
# get the outputs of the quat
@@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module):
# if drop:
# clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
if is_training and self.config.train_image_encoder:
self.image_encoder.train()
clip_image = clip_image.requires_grad_(True)
if self.preprocessor is not None:
@@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module):
clip_image, output_hidden_states=True
)
if self.config.type.startswith('ip+'):
if self.config.clip_layer == 'penultimate_hidden_states':
# they skip last layer for ip+
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
clip_image_embeds = clip_output.hidden_states[-2]
elif self.config.clip_layer == 'last_hidden_state':
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
if self.config.quad_image:
# get the outputs of the quat
chunks = clip_image_embeds.chunk(quad_count, dim=0)
if self.config.train_image_encoder and is_training:
# perform a loss across all chunks this will teach the vision encoder to
# identify similarities in our pairs of images and ignore things that do not make them similar
num_losses = 0
total_loss = None
for chunk in chunks:
for chunk2 in chunks:
if chunk is not chunk2:
loss = F.mse_loss(chunk, chunk2)
if total_loss is None:
total_loss = loss
else:
total_loss = total_loss + loss
num_losses += 1
if total_loss is not None:
total_loss = total_loss / num_losses
total_loss = total_loss * 1e-2
if self.additional_loss is not None:
total_loss = total_loss + self.additional_loss
self.additional_loss = total_loss
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
@@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module):
clip_image_embeds = chunk_sum / quad_count
if not is_training:
if not is_training or not self.config.train_image_encoder:
clip_image_embeds = clip_image_embeds.detach()
return clip_image_embeds
@@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module):
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings
def train(self: T, mode: bool = True) -> T:
if self.config.train_image_encoder:
self.image_encoder.train(mode)
if not self.config.train_only_image_encoder:
for attn_processor in self.adapter_modules:
attn_processor.train(mode)
if self.image_proj_model is not None:
self.image_proj_model.train(mode)
return super().train(mode)
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
yield from self.image_encoder.parameters(recurse)

View File

@@ -53,7 +53,6 @@ class InstantLoRAMidModule(torch.nn.Module):
# reshape if needed
if len(x.shape) == 3:
scaler = scaler.unsqueeze(1)
x = x * scaler
except Exception as e:
print(e)
print(x.shape)

View File

@@ -354,6 +354,7 @@ class ToolkitNetworkMixin:
self.is_ssd = is_ssd
self.is_vega = is_vega
self.is_v2 = is_v2
self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega
self.is_merged_in = False
self.is_lorm = is_lorm
self.network_config: NetworkConfig = network_config
@@ -361,7 +362,7 @@ class ToolkitNetworkMixin:
self.lorm_train_mode: Literal['local', None] = None
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
def get_keymap(self: Network, force_weight_mapping=False):
use_weight_mapping = False
if self.is_ssd:
@@ -377,6 +378,9 @@ class ToolkitNetworkMixin:
else:
keymap_tail = 'sd1'
# todo double check this
# use_weight_mapping = True
if force_weight_mapping:
use_weight_mapping = True
# load keymap
@@ -440,9 +444,9 @@ class ToolkitNetworkMixin:
else:
torch.save(save_dict, file)
def load_weights(self: Network, file):
def load_weights(self: Network, file, force_weight_mapping=False):
# allows us to save and load to and from ldm weights
keymap = self.get_keymap()
keymap = self.get_keymap(force_weight_mapping)
keymap = {} if keymap is None else keymap
if os.path.splitext(file)[1] == ".safetensors":
@@ -468,6 +472,11 @@ class ToolkitNetworkMixin:
for key in to_delete:
del load_sd[key]
print(f"Missing keys: {to_delete}")
if len(to_delete) > 0 and self.is_v1:
print(" Attempting to load with forced keymap")
return self.load_weights(file, force_weight_mapping=True)
info = self.load_state_dict(load_sd, False)
if len(extra_dict.keys()) == 0:
extra_dict = None

View File

@@ -528,6 +528,14 @@ class StableDiffusion:
)
gen_config.negative_prompt_2 = gen_config.negative_prompt
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
self.adapter.trigger_pre_te(
tensors_0_1=validation_image,
is_training=False,
has_been_preprocessed=False,
quad_count=4
)
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)