Fixed an issue with adapters that had only 1 input channel. Added the ability to set the percentage chance of adapter matching.

This commit is contained in:
Jaret Burkett
2023-10-15 15:13:35 -06:00
parent b1a22d0b3e
commit a05459afaf
2 changed files with 79 additions and 38 deletions

View File

@@ -22,6 +22,7 @@ def flush():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
adapter_transforms = transforms.Compose([ adapter_transforms = transforms.Compose([
# transforms.PILToTensor(), # transforms.PILToTensor(),
transforms.ToTensor(), transforms.ToTensor(),
@@ -51,7 +52,6 @@ class SDTrainer(BaseSDTrainProcess):
self.assistant_adapter.requires_grad_(False) self.assistant_adapter.requires_grad_(False)
flush() flush()
def hook_before_train_loop(self): def hook_before_train_loop(self):
# move vae to device if we did not cache latents # move vae to device if we did not cache latents
if not self.is_latents_cached: if not self.is_latents_cached:
@@ -62,7 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu') self.sd.vae.to('cpu')
flush() flush()
def hook_train_loop(self, batch): def hook_train_loop(self, batch):
self.timer.start('preprocess_batch') self.timer.start('preprocess_batch')
@@ -72,6 +71,17 @@ class SDTrainer(BaseSDTrainProcess):
has_adapter_img = batch.control_tensor is not None has_adapter_img = batch.control_tensor is not None
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch') self.timer.stop('preprocess_batch')
with torch.no_grad(): with torch.no_grad():
@@ -82,6 +92,12 @@ class SDTrainer(BaseSDTrainProcess):
# todo move this to data loader # todo move this to data loader
if batch.control_tensor is not None: if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
# match in channels
if self.assistant_adapter is not None:
in_channels = self.assistant_adapter.config.in_channels
if adapter_images.shape[1] != in_channels:
# we need to match the channels
adapter_images = adapter_images[:, :in_channels, :, :]
else: else:
raise NotImplementedError("Adapter images now must be loaded with dataloader") raise NotImplementedError("Adapter images now must be loaded with dataloader")
# not 100% sure what this does. But they do it here # not 100% sure what this does. But they do it here
@@ -106,7 +122,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, T2IAdapter): if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant. # training a t2i adapter, not using as assistant.
return 1.0 return 1.0
elif self.train_config.match_adapter_assist: elif match_adapter_assist:
# training a texture. We want it high # training a texture. We want it high
adapter_strength_min = 0.9 adapter_strength_min = 0.9
adapter_strength_max = 1.0 adapter_strength_max = 1.0
@@ -117,18 +133,18 @@ class SDTrainer(BaseSDTrainProcess):
adapter_strength_min = 0.9 adapter_strength_min = 0.9
adapter_strength_max = 1.1 adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand( adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype (1,), device=self.device_torch, dtype=dtype
) )
adapter_conditioning_scale = value_map( adapter_conditioning_scale = value_map(
adapter_conditioning_scale, adapter_conditioning_scale,
0.0, 0.0,
1.0, 1.0,
adapter_strength_min, adapter_strength_min,
adapter_strength_max adapter_strength_max
) )
return adapter_conditioning_scale return adapter_conditioning_scale
# flush() # flush()
with self.timer('grad_setup'): with self.timer('grad_setup'):
@@ -154,11 +170,26 @@ class SDTrainer(BaseSDTrainProcess):
# activate network if it exits # activate network if it exits
with network: with network:
with self.timer('encode_prompt'): with self.timer('encode_prompt'):
with torch.set_grad_enabled(grad_on_text_encoder): if grad_on_text_encoder:
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype) with torch.set_grad_enabled(True):
if not grad_on_text_encoder: conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
self.device_torch,
dtype=dtype)
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
self.device_torch,
dtype=dtype)
# detach the embeddings # detach the embeddings
conditional_embeds = conditional_embeds.detach() conditional_embeds = conditional_embeds.detach()
# flush() # flush()
pred_kwargs = {} pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
@@ -170,33 +201,38 @@ class SDTrainer(BaseSDTrainProcess):
if self.assistant_adapter: if self.assistant_adapter:
# not training. detach # not training. detach
down_block_additional_residuals = [ down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
] ]
else: else:
down_block_additional_residuals = [ down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
] ]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
control_pred = None control_pred = None
if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist: if has_adapter_img and self.assistant_adapter and match_adapter_assist:
# do a prediction here so we can match its output with network multiplier set to 0.0 with self.timer('predict_with_adapter'):
with torch.no_grad(): # do a prediction here so we can match its output with network multiplier set to 0.0
# dont use network on this with torch.no_grad():
network.multiplier = 0.0 # dont use network on this
control_pred = self.sd.predict_noise( network.multiplier = 0.0
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), self.sd.unet.eval()
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), control_pred = self.sd.predict_noise(
timestep=timesteps, latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
guidance_scale=1.0, conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
**pred_kwargs # adapter residuals in here timestep=timesteps,
) guidance_scale=1.0,
control_pred = control_pred.detach() **pred_kwargs # adapter residuals in here
# remove the residuals as we wont use them on prediction when matching control )
del pred_kwargs['down_block_additional_residuals'] self.sd.unet.train()
# restore network control_pred = control_pred.detach()
network.multiplier = network_weight_list # remove the residuals as we wont use them on prediction when matching control
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'): with self.timer('encode_adapter'):
@@ -204,7 +240,6 @@ class SDTrainer(BaseSDTrainProcess):
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
with self.timer('predict_unet'): with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise( noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype), latents=noisy_latents.to(self.device_torch, dtype=dtype),

View File

@@ -122,7 +122,13 @@ class TrainConfig:
self.start_step = kwargs.get('start_step', None) self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False) self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
self.match_adapter_chance = 1.0