Fix issue where flux2 would ignore single control image on training

This commit is contained in:
Jaret Burkett
2026-01-17 20:26:35 +00:00
parent e132dbae76
commit 0efed794b4

View File

@@ -312,7 +312,13 @@ class Flux2Model(BaseModel):
img_cond_seq_ids: torch.Tensor | None = None
# handle control images
if batch.control_tensor_list is not None:
batch_control_tensor_list = batch.control_tensor_list
if batch_control_tensor_list is None and batch.control_tensor is not None:
batch_control_tensor_list = []
for b in range(latent_model_input.shape[0]):
batch_control_tensor_list.append(batch.control_tensor[b : b + 1])
if batch_control_tensor_list is not None:
batch_size, num_channels_latents, height, width = (
latent_model_input.shape
)
@@ -328,11 +334,11 @@ class Flux2Model(BaseModel):
)
control_image_max_res = control_image_res
if len(batch.control_tensor_list) != batch_size:
if len(batch_control_tensor_list) != batch_size:
raise ValueError(
"Control tensor list length does not match batch size"
)
for control_tensor_list in batch.control_tensor_list:
for control_tensor_list in batch_control_tensor_list:
# control tensor list is a list of tensors for this batch item
controls = []
# pack control