mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added ability to add models to finetune as plugins. Also added flux2 new arch via that method.
This commit is contained in:
@@ -634,11 +634,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
return latents.detach()
|
||||
|
||||
control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype)
|
||||
if control_tensor is None:
|
||||
# concat random normal noise onto the latents
|
||||
# check dimension, this is before they are rearranged
|
||||
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
||||
# concat zeros onto the latents
|
||||
ctrl = torch.zeros(
|
||||
latents.shape[0], # bs
|
||||
latents.shape[1] * self.num_control_images, # ch
|
||||
@@ -656,6 +653,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
# if we have 1, it comes in like [bs, ch, h, w]
|
||||
# stack out control tensors to be [bs, ch * num_control_images, h, w]
|
||||
|
||||
control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype)
|
||||
|
||||
control_tensor_list = []
|
||||
if len(control_tensor.shape) == 4:
|
||||
control_tensor_list.append(control_tensor)
|
||||
|
||||
Reference in New Issue
Block a user