Handle multiple control inputs for control LoRA training

Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions

@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
         self.position_ids: Optional[List[int]] = None
-        self.num_control_images = 1
+        self.num_control_images = self.config.num_control_images
         self.token_mask: Optional[torch.Tensor] = None
         # setup clip
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
     # concat random normal noise onto the latents
     # check dimension, this is before they are rearranged
     # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
-    latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
+    ctrl = torch.randn(
+        latents.shape[0],  # bs
+        latents.shape[1] * self.num_control_images,  # ch
+        latents.shape[2],
+        latents.shape[3],
+        device=latents.device,
+        dtype=latents.dtype
+    )
+    latents = torch.cat((latents, ctrl), dim=1)
     return latents.detach()
-# it is 0-1 need to convert to -1 to 1
-control_tensor = control_tensor * 2 - 1
+# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
+# if we have 1, it comes in like [bs, ch, h, w]
+# stack our control tensors to be [bs, ch * num_control_images, h, w]
+control_tensor_list = []
+if len(control_tensor.shape) == 4:
+    control_tensor_list.append(control_tensor)
+else:
+    # reshape to [bs, num_control_images * ch, h, w]
+    control_tensor = control_tensor.view(
+        control_tensor.shape[0],
+        control_tensor.shape[1] * control_tensor.shape[2],
+        control_tensor.shape[3],
+        control_tensor.shape[4]
+    )
+    control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
+control_latent_list = []
+for control_tensor in control_tensor_list:
+    do_dropout = random.random() < self.config.control_image_dropout
+    if do_dropout:
+        # dropout: replace this control input with a zero latent
+        control_latent_list.append(torch.zeros_like(batch.latents))
+    else:
+        # it is 0-1, need to convert to -1 to 1
+        control_tensor = control_tensor * 2 - 1
+        control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
+        # if it is not the size of batch.tensor (bs, ch, h, w), we need to resize it
+        if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
+            control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
+        # encode it
+        control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
+        control_latent_list.append(control_latent)
-control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
-# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
-if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
-    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
-# encode it
-control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
+# stack them on the channel dimension
+control_latent = torch.cat(control_latent_list, dim=1)
 # concat it onto the latents
 latents = torch.cat((latents, control_latent), dim=1)
 return latents.detach()
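
A few notes on what the diff does, with small self-contained sketches. First, the adapter's control-image count now comes from config instead of being hard-coded to 1. A minimal sketch of what that means for the channel width of the model input (the AdapterConfig class and the latent channel count here are illustrative stand-ins, not code from the repo):

    from dataclasses import dataclass

    @dataclass
    class AdapterConfig:
        num_control_images: int = 1  # illustrative stand-in for the config field read above

    latent_ch = 16  # an assumed latent channel count, for illustration
    cfg = AdapterConfig(num_control_images=3)

    # base latents plus one latent-sized control block per control image
    input_ch = latent_ch * (1 + cfg.num_control_images)
    print(input_ch)  # 64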
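
Second, when no control tensor is present, the old code concatenated a single latent-sized block of noise; the new code concatenates num_control_images blocks so the channel layout always matches the training-time concat. A runnable sketch of the shape arithmetic (all sizes illustrative):

    import torch

    bs, ch, h, w = 2, 16, 32, 32
    num_control_images = 3
    latents = torch.randn(bs, ch, h, w)

    # one latent-sized noise block per expected control image
    ctrl = torch.randn(bs, ch * num_control_images, h, w,
                       device=latents.device, dtype=latents.dtype)
    model_input = torch.cat((latents, ctrl), dim=1)
    print(model_input.shape)  # torch.Size([2, 64, 32, 32])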
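
Third, the multi-control path folds the control-image axis into the channel axis, then splits it back into one tensor per control image. The view/chunk round trip is lossless, as this sketch (illustrative sizes) shows:

    import torch

    bs, n, ch, h, w = 2, 3, 3, 64, 64
    control = torch.rand(bs, n, ch, h, w)  # [bs, num_control_images, ch, h, w]

    # fold the control-image axis into channels: [bs, n * ch, h, w]
    flat = control.view(bs, n * ch, h, w)
    chunks = flat.chunk(n, dim=1)  # n tensors of shape [bs, ch, h, w]

    # each chunk is exactly the corresponding control image
    assert all(torch.equal(c, control[:, i]) for i, c in enumerate(chunks))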
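
Finally, each control image is independently dropped out (replaced by a zero latent) with probability control_image_dropout, or else rescaled from 0..1 to -1..1, resized to the batch resolution if needed, and VAE-encoded; all control latents are then concatenated on the channel axis. A sketch of that loop, with encode_images stubbed by average pooling since the real call is the model's VAE (the stub and all sizes are assumptions):

    import random
    import torch
    import torch.nn.functional as F

    def encode_images(x):
        # stand-in for sd.encode_images (the model's VAE encode); just 8x-downsamples
        return F.avg_pool2d(x, kernel_size=8)

    control_image_dropout = 0.1
    target_h, target_w = 64, 64  # batch.tensor spatial size, illustrative
    control_tensors = [torch.rand(2, 3, 32, 32) for _ in range(3)]  # per-image chunks, 0..1
    latent_template = encode_images(torch.zeros(2, 3, target_h, target_w))  # zero-latent shape

    control_latent_list = []
    for ct in control_tensors:
        if random.random() < control_image_dropout:
            # drop this control input: a zero latent stands in for it
            control_latent_list.append(torch.zeros_like(latent_template))
        else:
            ct = ct * 2 - 1  # images come in 0..1; the VAE expects -1..1
            if ct.shape[2] != target_h or ct.shape[3] != target_w:
                ct = F.interpolate(ct, size=(target_h, target_w), mode='bicubic')
            control_latent_list.append(encode_images(ct))

    # stack all control latents on the channel dimension
    control_latent = torch.cat(control_latent_list, dim=1)
    print(control_latent.shape)  # torch.Size([2, 9, 8, 8])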