mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Handle multi control inputs for control lora training
This commit is contained in:
@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
self.position_ids: Optional[List[int]] = None
|
||||
|
||||
self.num_control_images = 1
|
||||
self.num_control_images = self.config.num_control_images
|
||||
self.token_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# setup clip
|
||||
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
|
||||
# concat random normal noise onto the latents
|
||||
# check dimension, this is before they are rearranged
|
||||
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
||||
latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
|
||||
ctrl = torch.randn(
|
||||
latents.shape[0], # bs
|
||||
latents.shape[1] * self.num_control_images, # ch
|
||||
latents.shape[2],
|
||||
latents.shape[3],
|
||||
device=latents.device,
|
||||
dtype=latents.dtype
|
||||
)
|
||||
latents = torch.cat((latents, ctrl), dim=1)
|
||||
return latents.detach()
|
||||
# it is 0-1 need to convert to -1 to 1
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
|
||||
# if we have 1, it comes in like [bs, ch, h, w]
|
||||
# stack out control tensors to be [bs, ch * num_control_images, h, w]
|
||||
|
||||
control_tensor_list = []
|
||||
if len(control_tensor.shape) == 4:
|
||||
control_tensor_list.append(control_tensor)
|
||||
else:
|
||||
# reshape
|
||||
control_tensor = control_tensor.view(
|
||||
control_tensor.shape[0],
|
||||
control_tensor.shape[1] * control_tensor.shape[2],
|
||||
control_tensor.shape[3],
|
||||
control_tensor.shape[4]
|
||||
)
|
||||
control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
|
||||
control_latent_list = []
|
||||
for control_tensor in control_tensor_list:
|
||||
do_dropout = random.random() < self.config.control_image_dropout
|
||||
if do_dropout:
|
||||
# dropout with noise
|
||||
control_latent_list.append(torch.zeros_like(batch.latents))
|
||||
else:
|
||||
# it is 0-1 need to convert to -1 to 1
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
|
||||
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
||||
|
||||
# encode it
|
||||
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
||||
|
||||
# encode it
|
||||
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||
control_latent_list.append(control_latent)
|
||||
# stack them on the channel dimension
|
||||
control_latent = torch.cat(control_latent_list, dim=1)
|
||||
# concat it onto the latents
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
return latents.detach()
|
||||
|
||||
Reference in New Issue
Block a user