mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-21 04:43:58 +00:00
Fixed issue with adapters that only had 1 input channel. Added ability to set the percentage chance of adapter matching
This commit is contained in:
@@ -22,6 +22,7 @@ def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
# transforms.PILToTensor(),
|
||||
transforms.ToTensor(),
|
||||
@@ -51,7 +52,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.assistant_adapter.requires_grad_(False)
|
||||
flush()
|
||||
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
# move vae to device if we did not cache latents
|
||||
if not self.is_latents_cached:
|
||||
@@ -62,7 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
self.timer.start('preprocess_batch')
|
||||
@@ -72,6 +71,17 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
# check if we are matching the adapter assistant
|
||||
if self.assistant_adapter:
|
||||
if self.train_config.match_adapter_chance == 1.0:
|
||||
match_adapter_assist = True
|
||||
elif self.train_config.match_adapter_chance > 0.0:
|
||||
match_adapter_assist = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
) < self.train_config.match_adapter_chance
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -82,6 +92,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
# match in channels
|
||||
if self.assistant_adapter is not None:
|
||||
in_channels = self.assistant_adapter.config.in_channels
|
||||
if adapter_images.shape[1] != in_channels:
|
||||
# we need to match the channels
|
||||
adapter_images = adapter_images[:, :in_channels, :, :]
|
||||
else:
|
||||
raise NotImplementedError("Adapter images now must be loaded with dataloader")
|
||||
# not 100% sure what this does. But they do it here
|
||||
@@ -106,7 +122,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
||||
# training a t2i adapter, not using as assistant.
|
||||
return 1.0
|
||||
elif self.train_config.match_adapter_assist:
|
||||
elif match_adapter_assist:
|
||||
# training a texture. We want it high
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.0
|
||||
@@ -117,18 +133,18 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.1
|
||||
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
)
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
)
|
||||
|
||||
adapter_conditioning_scale = value_map(
|
||||
adapter_conditioning_scale,
|
||||
0.0,
|
||||
1.0,
|
||||
adapter_strength_min,
|
||||
adapter_strength_max
|
||||
)
|
||||
return adapter_conditioning_scale
|
||||
adapter_conditioning_scale = value_map(
|
||||
adapter_conditioning_scale,
|
||||
0.0,
|
||||
1.0,
|
||||
adapter_strength_min,
|
||||
adapter_strength_max
|
||||
)
|
||||
return adapter_conditioning_scale
|
||||
|
||||
# flush()
|
||||
with self.timer('grad_setup'):
|
||||
@@ -154,11 +170,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# activate network if it exits
|
||||
with network:
|
||||
with self.timer('encode_prompt'):
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
|
||||
if not grad_on_text_encoder:
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
with torch.set_grad_enabled(False):
|
||||
# make sure it is in eval mode
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
|
||||
@@ -170,33 +201,38 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.assistant_adapter:
|
||||
# not training. detach
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
else:
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
control_pred = None
|
||||
if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist:
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
with torch.no_grad():
|
||||
# dont use network on this
|
||||
network.multiplier = 0.0
|
||||
control_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
control_pred = control_pred.detach()
|
||||
# remove the residuals as we wont use them on prediction when matching control
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
# restore network
|
||||
network.multiplier = network_weight_list
|
||||
if has_adapter_img and self.assistant_adapter and match_adapter_assist:
|
||||
with self.timer('predict_with_adapter'):
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
with torch.no_grad():
|
||||
# dont use network on this
|
||||
network.multiplier = 0.0
|
||||
self.sd.unet.eval()
|
||||
control_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
self.sd.unet.train()
|
||||
control_pred = control_pred.detach()
|
||||
# remove the residuals as we wont use them on prediction when matching control
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
# restore network
|
||||
network.multiplier = network_weight_list
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
@@ -204,7 +240,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
|
||||
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
|
||||
@@ -122,7 +122,13 @@ class TrainConfig:
|
||||
self.start_step = kwargs.get('start_step', None)
|
||||
self.free_u = kwargs.get('free_u', False)
|
||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||
self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
self.match_adapter_chance = 1.0
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user