Fixed an issue with adapters that have only one input channel. Added the ability to set the percentage chance of adapter matching

Jaret Burkett
2023-10-15 15:13:35 -06:00
parent b1a22d0b3e
commit a05459afaf
2 changed files with 79 additions and 38 deletions

View File

@@ -22,6 +22,7 @@ def flush():
torch.cuda.empty_cache()
gc.collect()
adapter_transforms = transforms.Compose([
# transforms.PILToTensor(),
transforms.ToTensor(),
@@ -51,7 +52,6 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter.requires_grad_(False)
flush()
def hook_before_train_loop(self):
# move vae to device if we did not cache latents
if not self.is_latents_cached:
@@ -62,7 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
def hook_train_loop(self, batch):
self.timer.start('preprocess_batch')
@@ -72,6 +71,17 @@ class SDTrainer(BaseSDTrainProcess):
has_adapter_img = batch.control_tensor is not None
+ match_adapter_assist = False
+ # check if we are matching the adapter assistant
+ if self.assistant_adapter:
+     if self.train_config.match_adapter_chance == 1.0:
+         match_adapter_assist = True
+     elif self.train_config.match_adapter_chance > 0.0:
+         match_adapter_assist = torch.rand(
+             (1,), device=self.device_torch, dtype=dtype
+         ) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
with torch.no_grad():
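
The block added above gates adapter matching per step: a match_adapter_chance of 1.0 always matches, and a fractional chance draws a uniform sample and compares it to the threshold. A minimal standalone sketch of that gate (the helper name is illustrative, not from this repo):

import torch

def should_match_adapter(chance: float, device: str = "cpu") -> bool:
    # chance >= 1.0 always matches; chance <= 0.0 never does
    if chance >= 1.0:
        return True
    if chance <= 0.0:
        return False
    # torch.rand draws from U[0, 1), so P(draw < chance) == chance
    return bool(torch.rand((1,), device=device) < chance)

In the diff itself the comparison yields a one-element bool tensor rather than a Python bool; that still works in later truthiness checks, but casting makes the intent explicit.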
@@ -82,6 +92,12 @@ class SDTrainer(BaseSDTrainProcess):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
+ # match in channels
+ if self.assistant_adapter is not None:
+     in_channels = self.assistant_adapter.config.in_channels
+     if adapter_images.shape[1] != in_channels:
+         # we need to match the channels
+         adapter_images = adapter_images[:, :in_channels, :, :]
else:
raise NotImplementedError("Adapter images now must be loaded with dataloader")
# not 100% sure what this does. But they do it here
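
The channel fix above is what resolves the one-input-channel adapter issue from the commit message: the control tensor is sliced down to the adapter's expected input channels, e.g. an RGB control image feeding a single-channel depth or sketch adapter keeps only its first channel. The same idea in isolation (helper name illustrative):

import torch

def match_channels(control: torch.Tensor, in_channels: int) -> torch.Tensor:
    # control is NCHW; keep only the first in_channels channels
    if control.shape[1] > in_channels:
        control = control[:, :in_channels, :, :]
    return control

Note the slice only covers the too-many-channels case; a control tensor with fewer channels than the adapter expects would still fail downstream.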
@@ -106,7 +122,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant.
return 1.0
- elif self.train_config.match_adapter_assist:
+ elif match_adapter_assist:
# training a texture. We want it high
adapter_strength_min = 0.9
adapter_strength_max = 1.0
@@ -117,18 +133,18 @@ class SDTrainer(BaseSDTrainProcess):
adapter_strength_min = 0.9
adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
    (1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
    adapter_conditioning_scale,
    0.0,
    1.0,
    adapter_strength_min,
    adapter_strength_max
)
return adapter_conditioning_scale
# flush()
with self.timer('grad_setup'):
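
value_map is not defined in this diff; assuming it is the usual linear range remap, the draw above is equivalent to sampling the conditioning scale uniformly from [adapter_strength_min, adapter_strength_max]:

def value_map(x, in_min, in_max, out_min, out_max):
    # linearly remap x from [in_min, in_max] to [out_min, out_max]
    return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

# example: a uniform draw in [0, 1) mapped into [0.9, 1.1]
# scale = value_map(torch.rand(1), 0.0, 1.0, 0.9, 1.1)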
@@ -154,11 +170,26 @@ class SDTrainer(BaseSDTrainProcess):
# activate network if it exits
with network:
with self.timer('encode_prompt'):
- with torch.set_grad_enabled(grad_on_text_encoder):
-     conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
- if not grad_on_text_encoder:
+ if grad_on_text_encoder:
+     with torch.set_grad_enabled(True):
+         conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
+             self.device_torch,
+             dtype=dtype)
+ else:
+     with torch.set_grad_enabled(False):
+         # make sure it is in eval mode
+         if isinstance(self.sd.text_encoder, list):
+             for te in self.sd.text_encoder:
+                 te.eval()
+         else:
+             self.sd.text_encoder.eval()
+         conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
+             self.device_torch,
+             dtype=dtype)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
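
The rewritten prompt-encoding branch makes the gradient behavior explicit: gradients flow through the text encoder only when it is being trained; otherwise it is forced into eval mode and the embeddings are detached so no autograd graph is retained. The pattern in isolation (names illustrative):

import torch

def encode_prompts(encode_fn, prompts, train_text_encoder: bool):
    with torch.set_grad_enabled(train_text_encoder):
        embeds = encode_fn(prompts)
    if not train_text_encoder:
        # frozen encoder: drop any autograd graph from the embeddings
        embeds = embeds.detach()
    return embeds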
@@ -170,33 +201,38 @@ class SDTrainer(BaseSDTrainProcess):
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
control_pred = None
- if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist:
-     # do a prediction here so we can match its output with network multiplier set to 0.0
-     with torch.no_grad():
-         # dont use network on this
-         network.multiplier = 0.0
-         control_pred = self.sd.predict_noise(
-             latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
-             conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
-             timestep=timesteps,
-             guidance_scale=1.0,
-             **pred_kwargs  # adapter residuals in here
-         )
-         control_pred = control_pred.detach()
-         # remove the residuals as we wont use them on prediction when matching control
-         del pred_kwargs['down_block_additional_residuals']
-         # restore network
-         network.multiplier = network_weight_list
+ if has_adapter_img and self.assistant_adapter and match_adapter_assist:
+     with self.timer('predict_with_adapter'):
+         # do a prediction here so we can match its output with network multiplier set to 0.0
+         with torch.no_grad():
+             # dont use network on this
+             network.multiplier = 0.0
+             self.sd.unet.eval()
+             control_pred = self.sd.predict_noise(
+                 latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                 conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
+                 timestep=timesteps,
+                 guidance_scale=1.0,
+                 **pred_kwargs  # adapter residuals in here
+             )
+             self.sd.unet.train()
+             control_pred = control_pred.detach()
+             # remove the residuals as we wont use them on prediction when matching control
+             del pred_kwargs['down_block_additional_residuals']
+             # restore network
+             network.multiplier = network_weight_list
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
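
The block above computes a reference prediction with the trainable network zeroed out (multiplier 0.0) but the adapter residuals still applied, giving a target the network can later be matched against; the UNet is flipped to eval and back so the pass does not disturb training-mode layers. The save/zero/restore pattern in isolation (helper name illustrative):

import torch

def adapter_reference_pred(predict_fn, network, restore_multiplier, **pred_kwargs):
    network.multiplier = 0.0  # bypass the trainable network for this pass
    with torch.no_grad():
        pred = predict_fn(**pred_kwargs).detach()
    network.multiplier = restore_multiplier  # put the network back
    return pred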
@@ -204,7 +240,6 @@ class SDTrainer(BaseSDTrainProcess):
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),

View File

@@ -122,7 +122,13 @@ class TrainConfig:
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
- self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
+ match_adapter_assist = kwargs.get('match_adapter_assist', False)
+ self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
+ # legacy
+ if match_adapter_assist and self.match_adapter_chance == 0.0:
+     self.match_adapter_chance = 1.0
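
The config change keeps the legacy boolean working: match_adapter_assist: true with no explicit chance now behaves as a 100% match chance. Roughly, for the kwargs shown above:

def resolve_match_chance(kwargs: dict) -> float:
    legacy_assist = kwargs.get('match_adapter_assist', False)
    chance = kwargs.get('match_adapter_chance', 0.0)
    # legacy flag meant "always match"; map it to chance 1.0
    if legacy_assist and chance == 0.0:
        chance = 1.0
    return chance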