Added ability to add models to finetune as plugins. Also added flux2 new arch via that method.

This commit is contained in:
Jaret Burkett
2025-03-27 16:07:00 -06:00
parent e9e30104d3
commit 5365200da1
12 changed files with 936 additions and 1058 deletions

View File

@@ -634,11 +634,8 @@ class CustomAdapter(torch.nn.Module):
latents = torch.cat((latents, control_latent), dim=1)
return latents.detach()
control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype)
if control_tensor is None:
# concat random normal noise onto the latents
# check dimension, this is before they are rearranged
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
# concat zeros onto the latents
ctrl = torch.zeros(
latents.shape[0], # bs
latents.shape[1] * self.num_control_images, # ch
@@ -656,6 +653,8 @@ class CustomAdapter(torch.nn.Module):
# if we have 1, it comes in like [bs, ch, h, w]
# stack out control tensors to be [bs, ch * num_control_images, h, w]
control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype)
control_tensor_list = []
if len(control_tensor.shape) == 4:
control_tensor_list.append(control_tensor)